From bc6dbfeb991141340ea3234afabbdee33287027d Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sat, 14 Oct 2023 12:32:22 +0800 Subject: [PATCH] [MNT] resolve all comments in basic_nn.py --- abl/bridge/simple_bridge.py | 4 +- abl/learning/basic_nn.py | 274 ++++++++++++++++-------------------- 2 files changed, 126 insertions(+), 152 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 7286d42..3aecffd 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -74,10 +74,10 @@ class SimpleBridge(BaseBridge): pred_prob, pred_pseudo_label, Y ) abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) - min_loss = self.model.train(X, abduced_label) + loss = self.model.train(X, abduced_label) print_log( - f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}", + f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}", logger="current", ) diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 28cf20b..b9b0b36 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -10,10 +10,6 @@ # # ================================================================# -import sys - -sys.path.append("..") - import torch import numpy from torch.utils.data import DataLoader @@ -21,12 +17,12 @@ from ..utils.logger import print_log from ..dataset import ClassificationDataset import os -from typing import List, Any, T, Optional, Callable +from typing import List, Any, T, Optional, Callable, Tuple class BasicNN: """ - Wrap NN models into the form of an sklearn estimator + Wrap NN models into the form of an sklearn estimator. Parameters ---------- @@ -34,83 +30,35 @@ class BasicNN: The PyTorch model to be trained or used for prediction. criterion : torch.nn.Module The loss function used for training. - optimizer : torch.nn.Module + optimizer : torch.optim.Optimizer The optimizer used for training. device : torch.device, optional - The device on which the model will be trained or used for prediction, by default torch.decive("cpu"). + The device on which the model will be trained or used for prediction, by default torch.device("cpu"). batch_size : int, optional - The batch size used for training, by default 1. + The batch size used for training, by default 32. num_epochs : int, optional The number of epochs used for training, by default 1. stop_loss : Optional[float], optional The loss value at which to stop training, by default 0.01. - num_workers : int, optional + num_workers : int The number of workers used for loading data, by default 0. save_interval : Optional[int], optional The interval at which to save the model during training, by default None. save_dir : Optional[str], optional The directory in which to save the model during training, by default None. transform : Callable[..., Any], optional - A function/transform that takes in an object and returns a transformed version. Defaults to None. + A function/transform that takes in an object and returns a transformed version, by default None. collate_fn : Callable[[List[T]], Any], optional The function used to collate data, by default None. - - Attributes - ---------- - model : torch.nn.Module - The PyTorch model to be trained or used for prediction. - batch_size : int - The batch size used for training. - num_epochs : int - The number of epochs used for training. - stop_loss : Optional[float] - The loss value at which to stop training. - num_workers : int - The number of workers used for loading data. - criterion : torch.nn.Module - The loss function used for training. - optimizer : torch.nn.Module - The optimizer used for training. - transform : Callable[..., Any] - The transformation function used for data augmentation. - device : torch.device - The device on which the model will be trained or used for prediction. - save_interval : Optional[int] - The interval at which to save the model during training. - save_dir : Optional[str] - The directory in which to save the model during training. - collate_fn : Callable[[List[T]], Any] - The function used to collate data. - - Methods - ------- - fit(data_loader=None, X=None, y=None) - Train the model. - train_epoch(data_loader) - Train the model for one epoch. - predict(data_loader=None, X=None, print_prefix="") - Predict the class of the input data. - predict_proba(data_loader=None, X=None, print_prefix="") - Predict the probability of each class for the input data. - val(data_loader=None, X=None, y=None, print_prefix="") - Validate the model. - score(data_loader=None, X=None, y=None, print_prefix="") - Score the model. - _data_loader(X, y=None) - Generate the data_loader. - save(epoch_id, save_dir="") - Save the model. - load(epoch_id, load_dir="") - Load the model. """ def __init__( self, model: torch.nn.Module, criterion: torch.nn.Module, - optimizer: torch.nn.Module, + optimizer: torch.optim.Optimizer, device: torch.device = torch.device("cpu"), - batch_size: int = 1, + batch_size: int = 32, num_epochs: int = 1, stop_loss: Optional[float] = 0.01, num_workers: int = 0, @@ -118,38 +66,46 @@ class BasicNN: save_dir: Optional[str] = None, transform: Callable[..., Any] = None, collate_fn: Callable[[List[T]], Any] = None, - ): + ) -> None: self.model = model.to(device) - + self.criterion = criterion + self.optimizer = optimizer + self.device = device self.batch_size = batch_size self.num_epochs = num_epochs self.stop_loss = stop_loss self.num_workers = num_workers - - self.criterion = criterion - self.optimizer = optimizer - self.transform = transform - self.device = device - self.save_interval = save_interval self.save_dir = save_dir + self.transform = transform self.collate_fn = collate_fn - def _fit(self, data_loader, n_epoch, stop_loss): - min_loss = 1e10 - for epoch in range(n_epoch): + def _fit(self, data_loader) -> float: + """ + Internal method to fit the model on data for n epochs, with early stopping. + + Parameters + ---------- + data_loader : DataLoader + Data loader providing training samples. + + Returns + ------- + float + The loss value of the trained model. + """ + loss_value = 1e9 + for epoch in range(self.num_epochs): loss_value = self.train_epoch(data_loader) - if min_loss < 0 or loss_value < min_loss: - min_loss = loss_value if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: if self.save_dir is None: raise ValueError( - "save_dir should not be None if save_interval is not None" + "save_dir should not be None if save_interval is not None." ) - self.save(epoch + 1, self.save_dir) - if stop_loss is not None and loss_value < stop_loss: + self.save(epoch + 1) + if self.stop_loss is not None and loss_value < self.stop_loss: break - return min_loss + return loss_value def fit( self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None @@ -160,11 +116,11 @@ class BasicNN: Parameters ---------- data_loader : DataLoader, optional - The data loader used for training, by default None + The data loader used for training, by default None. X : List[Any], optional - The input data, by default None + The input data, by default None. y : List[int], optional - The target data, by default None + The target data, by default None. Returns ------- @@ -172,10 +128,13 @@ class BasicNN: The loss value of the trained model. """ if data_loader is None: - data_loader = self._data_loader(X, y) - return self._fit(data_loader, self.num_epochs, self.stop_loss) + if X is None: + raise ValueError("data_loader and X can not be None simultaneously.") + else: + data_loader = self._data_loader(X, y) + return self._fit(data_loader) - def train_epoch(self, data_loader: DataLoader): + def train_epoch(self, data_loader: DataLoader) -> float: """ Train the model for one epoch. @@ -211,7 +170,20 @@ class BasicNN: return total_loss / total_num - def _predict(self, data_loader): + def _predict(self, data_loader) -> torch.Tensor: + """ + Internal method to predict the outputs given a DataLoader. + + Parameters + ---------- + data_loader : DataLoader + The DataLoader providing input samples. + + Returns + ------- + torch.Tensor + Raw output from the model. + """ model = self.model device = self.device @@ -227,10 +199,7 @@ class BasicNN: return torch.cat(results, axis=0) def predict( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - print_prefix: str = "", + self, data_loader: DataLoader = None, X: List[Any] = None ) -> numpy.ndarray: """ Predict the class of the input data. @@ -238,11 +207,9 @@ class BasicNN: Parameters ---------- data_loader : DataLoader, optional - The data loader used for prediction, by default None + The data loader used for prediction, by default None. X : List[Any], optional - The input data, by default None - print_prefix : str, optional - The prefix used for printing, by default "" + The input data, by default None. Returns ------- @@ -255,10 +222,7 @@ class BasicNN: return self._predict(data_loader).argmax(axis=1).cpu().numpy() def predict_proba( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - print_prefix: str = "", + self, data_loader: DataLoader = None, X: List[Any] = None ) -> numpy.ndarray: """ Predict the probability of each class for the input data. @@ -266,11 +230,9 @@ class BasicNN: Parameters ---------- data_loader : DataLoader, optional - The data loader used for prediction, by default None + The data loader used for prediction, by default None. X : List[Any], optional - The input data, by default None - print_prefix : str, optional - The prefix used for printing, by default "" + The input data, by default None. Returns ------- @@ -282,7 +244,21 @@ class BasicNN: data_loader = self._data_loader(X) return self._predict(data_loader).softmax(axis=1).cpu().numpy() - def _score(self, data_loader): + def _score(self, data_loader) -> Tuple[float, float]: + """ + Internal method to compute loss and accuracy for the data provided through a DataLoader. + + Parameters + ---------- + data_loader : DataLoader + Data loader to use for evaluation. + + Returns + ------- + Tuple[float, float] + mean_loss: float, The mean loss of the model on the provided data. + accuracy: float, The accuracy of the model on the provided data. + """ model = self.model criterion = self.criterion device = self.device @@ -298,9 +274,9 @@ class BasicNN: out = model(data) if len(out.shape) > 1: - correct_num = sum(target == out.argmax(axis=1)).item() + correct_num = (target == out.argmax(axis=1)).sum().item() else: - correct_num = sum(target == (out > 0.5)).item() + correct_num = (target == (out > 0.5)).sum().item() loss = criterion(out, target) total_loss += loss.item() * data.size(0) @@ -313,11 +289,7 @@ class BasicNN: return mean_loss, accuracy def score( - self, - data_loader: DataLoader = None, - X: List[Any] = None, - y: List[int] = None, - print_prefix: str = "", + self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None ) -> float: """ Validate the model. @@ -325,25 +297,25 @@ class BasicNN: Parameters ---------- data_loader : DataLoader, optional - The data loader used for scoring, by default None + The data loader used for scoring, by default None. X : List[Any], optional - The input data, by default None + The input data, by default None. y : List[int], optional - The target data, by default None - print_prefix : str, optional - The prefix used for printing, by default "" + The target data, by default None. Returns ------- float The accuracy of the model. """ - print_log(f"Start machine learning model validation", logger="current") + print_log("Start machine learning model validation", logger="current") if data_loader is None: data_loader = self._data_loader(X, y) mean_loss, accuracy = self._score(data_loader) - print_log(f"{print_prefix} mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") + print_log( + f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current" + ) return accuracy def _data_loader( @@ -352,38 +324,39 @@ class BasicNN: y: List[int] = None, ) -> DataLoader: """ - Generate data_loader for user provided data. + Generate a DataLoader for user-provided input and target data. Parameters ---------- X : List[Any] - The input data. + Input samples. y : List[int], optional - The target data, by default None + Target labels. If None, dummy labels are created, by default None. Returns ------- DataLoader - The data loader. + A DataLoader providing batches of (X, y) pairs. """ - collate_fn = self.collate_fn - transform = self.transform + if X is None: + raise ValueError("X should not be None.") if y is None: y = [0] * len(X) - dataset = ClassificationDataset(X, y, transform=transform) - sampler = None + if not (len(y) == len(X)): + raise ValueError("X and y should have equal length.") + + dataset = ClassificationDataset(X, y, transform=self.transform) data_loader = DataLoader( dataset, batch_size=self.batch_size, - shuffle=False, - sampler=sampler, + shuffle=True, num_workers=int(self.num_workers), - collate_fn=collate_fn, + collate_fn=self.collate_fn, ) return data_loader - def save(self, epoch_id: int = 0, save_dir: str = None, save_path: str = None): + def save(self, epoch_id: int = 0, save_path: str = None) -> None: """ Save the model and the optimizer. @@ -391,15 +364,20 @@ class BasicNN: ---------- epoch_id : int The epoch id. - save_dir : str, optional - The directory to save the model, by default "" + save_path : str, optional + The path to save the model, by default None. """ - if save_dir and (not os.path.exists(save_dir)): - os.makedirs(save_dir) - print_log(f"Checkpoints will be saved to {save_dir}", logger="current") - + if self.save_dir is None and save_path is None: + raise ValueError( + "'save_dir' and 'save_path' should not be None simultaneously." + ) + if save_path is None: - save_path = os.path.join(save_dir, str(epoch_id) + ".pth") + save_path = os.path.join( + self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth" + ) + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) print_log(f"Checkpoints will be saved to {save_path}", logger="current") @@ -410,29 +388,25 @@ class BasicNN: torch.save(save_parma_dic, save_path) - def load(self, epoch_id: int = 0, load_dir: str = "", load_path: str = None): + def load(self, load_path: str = "") -> None: """ Load the model and the optimizer. Parameters ---------- - epoch_id : int - The epoch id. - load_dir : str, optional - The directory to load the model, by default "" + load_path : str + The directory to load the model, by default "". """ - if load_path is not None: - print_log(f"Loads checkpoint by local backend from path: {load_path}", logger="current") - else: - print_log(f"Loads checkpoint by local backend from dir: {load_dir}", logger="current") - load_path = os.path.join(load_dir, str(epoch_id) + ".pth") - + if load_path is None: + raise ValueError("Load path should not be None.") + + print_log( + f"Loads checkpoint by local backend from path: {load_path}", + logger="current", + ) + param_dic = torch.load(load_path) self.model.load_state_dict(param_dic["model"]) if "optimizer" in param_dic.keys(): self.optimizer.load_state_dict(param_dic["optimizer"]) - - -if __name__ == "__main__": - pass