From 2e73cfd9b0e63b52149a0126e0dd4a5b33392e41 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 7 Apr 2023 14:06:09 +0800 Subject: [PATCH] [MNT] rename modules in abl.learning --- abl/learning/abl_model.py | 137 +++++ abl/learning/basic_nn.py | 560 +++++++++++++++++++++ examples/hed/hed_example.ipynb | 21 +- examples/hwf/hwf_example.ipynb | 25 +- examples/mnist_add/mnist_add_example.ipynb | 25 +- 5 files changed, 740 insertions(+), 28 deletions(-) create mode 100644 abl/learning/abl_model.py create mode 100644 abl/learning/basic_nn.py diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py new file mode 100644 index 0000000..2f0feeb --- /dev/null +++ b/abl/learning/abl_model.py @@ -0,0 +1,137 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2020 Freecss All rights reserved. +# +# File Name :models.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2020/04/02 +# Description : +# +# ================================================================# +from itertools import chain +from typing import List, Any + + +def get_part_data(X, i): + return list(map(lambda x: x[i], X)) + + +def merge_data(X): + ret_mark = list(map(lambda x: len(x), X)) + ret_X = list(chain(*X)) + return ret_X, ret_mark + + +def reshape_data(Y, marks): + begin_mark = 0 + ret_Y = [] + for mark in marks: + end_mark = begin_mark + mark + ret_Y.append(Y[begin_mark:end_mark]) + begin_mark = end_mark + return ret_Y + + +class ABLModel: + """ + Serialize data and provide a unified interface for different machine learning models. + + Parameters + ---------- + base_model : Machine Learning Model + The base model to use for training and prediction. + pseudo_label_list : List[Any] + A list of pseudo labels to use for training. + + Attributes + ---------- + cls_list : List[Any] + A list of classifiers. + pseudo_label_list : List[Any] + A list of pseudo labels to use for training. 
+ mapping : dict + A dictionary mapping pseudo labels to integers. + remapping : dict + A dictionary mapping integers to pseudo labels. + + Methods + ------- + predict(X: List[List[Any]]) -> dict + Predict the class labels and probabilities for the given data. + valid(X: List[List[Any]], Y: List[Any]) -> float + Calculate the accuracy score for the given data. + train(X: List[List[Any]], Y: List[Any]) + Train the model on the given data. + """ + def __init__(self, base_model, pseudo_label_list: List[Any]): + self.cls_list = [] + self.cls_list.append(base_model) + + self.pseudo_label_list = pseudo_label_list + self.mapping = dict(zip(pseudo_label_list, list(range(len(pseudo_label_list))))) + self.remapping = dict( + zip(list(range(len(pseudo_label_list))), pseudo_label_list) + ) + + def predict(self, X: List[List[Any]]) -> dict: + """ + Predict the class labels and probabilities for the given data. + + Parameters + ---------- + X : List[List[Any]] + The data to predict on. + + Returns + ------- + dict + A dictionary containing the predicted class labels and probabilities. + """ + data_X, marks = merge_data(X) + prob = self.cls_list[0].predict_proba(X=data_X) + _cls = prob.argmax(axis=1) + cls = list(map(lambda x: self.remapping[x], _cls)) + + prob = reshape_data(prob, marks) + cls = reshape_data(cls, marks) + + return {"cls": cls, "prob": prob} + + def valid(self, X: List[List[Any]], Y: List[Any]) -> float: + """ + Calculate the accuracy for the given data. + + Parameters + ---------- + X : List[List[Any]] + The data to calculate the accuracy on. + Y : List[Any] + The true class labels for the given data. + + Returns + ------- + float + The accuracy score for the given data. + """ + data_X, _ = merge_data(X) + _data_Y, _ = merge_data(Y) + data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + score = self.cls_list[0].score(X=data_X, y=data_Y) + return score + + def train(self, X: List[List[Any]], Y: List[Any]): + """ + Train the model on the given data. 
+ + Parameters + ---------- + X : List[List[Any]] + The data to train on. + Y : List[Any] + The true class labels for the given data. + """ + data_X, _ = merge_data(X) + _data_Y, _ = merge_data(Y) + data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + self.cls_list[0].fit(X=data_X, y=data_Y) diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py new file mode 100644 index 0000000..ed0b28c --- /dev/null +++ b/abl/learning/basic_nn.py @@ -0,0 +1,560 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2020 Freecss All rights reserved. +# +# File Name :basic_model.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2020/11/21 +# Description : +# +# ================================================================# + +import sys + +sys.path.append("..") + +import torch +import numpy +from torch.utils.data import Dataset, DataLoader + +import os +from multiprocessing import Pool +from typing import List, Any, T, Tuple, Optional, Callable + + +class BasicDataset(Dataset): + def __init__(self, X: List[Any], Y: List[Any]): + """Initialize a basic dataset. + + Parameters + ---------- + X : List[Any] + A list of objects representing the input data. + Y : List[Any] + A list of objects representing the output data. + """ + self.X = X + self.Y = Y + + def __len__(self): + """Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Get an item from the dataset. + + Parameters + ---------- + index : int + The index of the item to retrieve. + + Returns + ------- + Tuple[Any, Any] + A tuple containing the input and output data at the specified index. 
+ """ + if index >= len(self): + raise ValueError("index range error") + + img = self.X[index] + label = self.Y[index] + + return (img, label) + + +class XYDataset(Dataset): + def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None): + """ + Initialize the dataset used for classification task. + + Parameters + ---------- + X : List[Any] + The input data. + Y : List[int] + The target data. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. + """ + self.X = X + self.Y = torch.LongTensor(Y) + + self.n_sample = len(X) + self.transform = transform + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: + """ + Get the item at the given index. + + Parameters + ---------- + index : int + The index of the item to get. + + Returns + ------- + Tuple[Any, torch.Tensor] + A tuple containing the object and its label. + """ + if index >= len(self): + raise ValueError("index range error") + + img = self.X[index] + if self.transform is not None: + img = self.transform(img) + + label = self.Y[index] + + return (img, label) + + +class FakeRecorder: + def __init__(self): + pass + + def print(self, *x): + pass + + +class BasicNN: + """ + Wrap NN models into the form of an sklearn estimator + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to be trained or used for prediction. + criterion : torch.nn.Module + The loss function used for training. + optimizer : torch.nn.Module + The optimizer used for training. + device : torch.device, optional + The device on which the model will be trained or used for prediction, by default torch.device("cpu"). + batch_size : int, optional + The batch size used for training, by default 1. 
+ num_epochs : int, optional + The number of epochs used for training, by default 1. + stop_loss : Optional[float], optional + The loss value at which to stop training, by default 0.01. + num_workers : int, optional + The number of workers used for loading data, by default 0. + save_interval : Optional[int], optional + The interval at which to save the model during training, by default None. + save_dir : Optional[str], optional + The directory in which to save the model during training, by default None. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. + collate_fn : Callable[[List[T]], Any], optional + The function used to collate data, by default None. + recorder : Any, optional + The recorder used to record training progress, by default None. + + Attributes + ---------- + model : torch.nn.Module + The PyTorch model to be trained or used for prediction. + batch_size : int + The batch size used for training. + num_epochs : int + The number of epochs used for training. + stop_loss : Optional[float] + The loss value at which to stop training. + num_workers : int + The number of workers used for loading data. + criterion : torch.nn.Module + The loss function used for training. + optimizer : torch.nn.Module + The optimizer used for training. + transform : Callable[..., Any] + The transformation function used for data augmentation. + device : torch.device + The device on which the model will be trained or used for prediction. + recorder : Any + The recorder used to record training progress. + save_interval : Optional[int] + The interval at which to save the model during training. + save_dir : Optional[str] + The directory in which to save the model during training. + collate_fn : Callable[[List[T]], Any] + The function used to collate data. + + Methods + ------- + fit(data_loader=None, X=None, y=None) + Train the model. 
+ train_epoch(data_loader) + Train the model for one epoch. + predict(data_loader=None, X=None, print_prefix="") + Predict the class of the input data. + predict_proba(data_loader=None, X=None, print_prefix="") + Predict the probability of each class for the input data. + val(data_loader=None, X=None, y=None, print_prefix="") + Validate the model. + score(data_loader=None, X=None, y=None, print_prefix="") + Score the model. + _data_loader(X, y=None) + Generate the data_loader. + save(epoch_id, save_dir="") + Save the model. + load(epoch_id, load_dir="") + Load the model. + """ + + def __init__( + self, + model: torch.nn.Module, + criterion: torch.nn.Module, + optimizer: torch.nn.Module, + device: torch.device = torch.device("cpu"), + batch_size: int = 1, + num_epochs: int = 1, + stop_loss: Optional[float] = 0.01, + num_workers: int = 0, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, + transform: Callable[..., Any] = None, + collate_fn: Callable[[List[T]], Any] = None, + recorder=None, + ): + + self.model = model.to(device) + + self.batch_size = batch_size + self.num_epochs = num_epochs + self.stop_loss = stop_loss + self.num_workers = num_workers + + self.criterion = criterion + self.optimizer = optimizer + self.transform = transform + self.device = device + + if recorder is None: + recorder = FakeRecorder() + self.recorder = recorder + + self.save_interval = save_interval + self.save_dir = save_dir + self.collate_fn = collate_fn + + def _fit(self, data_loader, n_epoch, stop_loss): + recorder = self.recorder + recorder.print("model fitting") + + min_loss = 1e10 + for epoch in range(n_epoch): + loss_value = self.train_epoch(data_loader) + recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}") + if min_loss < 0 or loss_value < min_loss: + min_loss = loss_value + if self.save_interval is not None and (epoch + 1) % self.save_interval == 0: + if self.save_dir is None: + raise ValueError( + "save_dir should not be None if 
save_interval is not None" + ) + self.save(epoch + 1, self.save_dir) + if stop_loss is not None and loss_value < stop_loss: + break + recorder.print("Model fitted, minimal loss is ", min_loss) + return loss_value + + def fit( + self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None + ) -> float: + """ + Train the model. + + Parameters + ---------- + data_loader : DataLoader, optional + The data loader used for training, by default None + X : List[Any], optional + The input data, by default None + y : List[int], optional + The target data, by default None + + Returns + ------- + float + The loss value of the trained model. + """ + if data_loader is None: + data_loader = self._data_loader(X, y) + return self._fit(data_loader, self.num_epochs, self.stop_loss) + + def train_epoch(self, data_loader: DataLoader): + """ + Train the model for one epoch. + + Parameters + ---------- + data_loader : DataLoader + The data loader used for training. + + Returns + ------- + float + The loss value of the trained model. + """ + model = self.model + criterion = self.criterion + optimizer = self.optimizer + device = self.device + + model.train() + + total_loss, total_num = 0.0, 0 + for data, target in data_loader: + data, target = data.to(device), target.to(device) + out = model(data) + loss = criterion(out, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() * data.size(0) + total_num += data.size(0) + + return total_loss / total_num + + def _predict(self, data_loader): + model = self.model + device = self.device + + model.eval() + + with torch.no_grad(): + results = [] + for data, _ in data_loader: + data = data.to(device) + out = model(data) + results.append(out) + + return torch.cat(results, axis=0) + + def predict( + self, + data_loader: DataLoader = None, + X: List[Any] = None, + print_prefix: str = "", + ) -> numpy.ndarray: + """ + Predict the class of the input data. 
+ + Parameters + ---------- + data_loader : DataLoader, optional + The data loader used for prediction, by default None + X : List[Any], optional + The input data, by default None + print_prefix : str, optional + The prefix used for printing, by default "" + + Returns + ------- + numpy.ndarray + The predicted class of the input data. + """ + recorder = self.recorder + recorder.print("Start Predict Class ", print_prefix) + + if data_loader is None: + data_loader = self._data_loader(X) + return self._predict(data_loader).argmax(axis=1).cpu().numpy() + + def predict_proba( + self, + data_loader: DataLoader = None, + X: List[Any] = None, + print_prefix: str = "", + ) -> numpy.ndarray: + """ + Predict the probability of each class for the input data. + + Parameters + ---------- + data_loader : DataLoader, optional + The data loader used for prediction, by default None + X : List[Any], optional + The input data, by default None + print_prefix : str, optional + The prefix used for printing, by default "" + + Returns + ------- + numpy.ndarray + The predicted probability of each class for the input data. 
+ """ + recorder = self.recorder + recorder.print("Start Predict Probability ", print_prefix) + + if data_loader is None: + data_loader = self._data_loader(X) + return self._predict(data_loader).softmax(axis=1).cpu().numpy() + + def _score(self, data_loader): + model = self.model + criterion = self.criterion + device = self.device + + model.eval() + + total_correct_num, total_num, total_loss = 0, 0, 0.0 + + with torch.no_grad(): + for data, target in data_loader: + data, target = data.to(device), target.to(device) + + out = model(data) + + if len(out.shape) > 1: + correct_num = sum(target == out.argmax(axis=1)).item() + else: + correct_num = sum(target == (out > 0.5)).item() + loss = criterion(out, target) + total_loss += loss.item() * data.size(0) + + total_correct_num += correct_num + total_num += data.size(0) + + mean_loss = total_loss / total_num + accuracy = total_correct_num / total_num + + return mean_loss, accuracy + + def score( + self, + data_loader: DataLoader = None, + X: List[Any] = None, + y: List[int] = None, + print_prefix: str = "", + ) -> float: + """ + Validate the model. + + Parameters + ---------- + data_loader : DataLoader, optional + The data loader used for scoring, by default None + X : List[Any], optional + The input data, by default None + y : List[int], optional + The target data, by default None + print_prefix : str, optional + The prefix used for printing, by default "" + + Returns + ------- + float + The accuracy of the model. + """ + recorder = self.recorder + recorder.print("Start validation ", print_prefix) + + if data_loader is None: + data_loader = self._data_loader(X, y) + mean_loss, accuracy = self._score(data_loader) + recorder.print( + "[%s] mean loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) + ) + return accuracy + + def _data_loader( + self, + X: List[Any], + y: List[int] = None, + ) -> DataLoader: + """ + Generate data_loader for user provided data. 
+ + Parameters + ---------- + X : List[Any] + The input data. + y : List[int], optional + The target data, by default None + + Returns + ------- + DataLoader + The data loader. + """ + collate_fn = self.collate_fn + transform = self.transform + + if y is None: + y = [0] * len(X) + dataset = XYDataset(X, y, transform=transform) + sampler = None + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + num_workers=int(self.num_workers), + collate_fn=collate_fn, + ) + return data_loader + + def save(self, epoch_id: int, save_dir: str = ""): + """ + Save the model and the optimizer. + + Parameters + ---------- + epoch_id : int + The epoch id. + save_dir : str, optional + The directory to save the model, by default "" + """ + recorder = self.recorder + if not os.path.exists(save_dir): + os.makedirs(save_dir) + recorder.print("Saving model and opter") + save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth") + torch.save(self.model.state_dict(), save_path) + + save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth") + torch.save(self.optimizer.state_dict(), save_path) + + def load(self, epoch_id: int, load_dir: str = ""): + """ + Load the model and the optimizer. + + Parameters + ---------- + epoch_id : int + The epoch id. 
+ load_dir : str, optional + The directory to load the model, by default "" + """ + recorder = self.recorder + recorder.print("Loading model and opter") + load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth") + self.model.load_state_dict(torch.load(load_path)) + + load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth") + self.optimizer.load_state_dict(torch.load(load_path)) + + +if __name__ == "__main__": + pass diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index 0a126d8..496acaf 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -18,8 +18,8 @@ "from abl.abducer.kb import prolog_KB\n", "\n", "from abl.utils.plog import logger\n", - "from abl.models.basic_model import BasicModel\n", - "from abl.models.wabl_models import WABLBasicModel\n", + "from abl.models.basic_nn import BasicNN\n", + "from abl.models.abl_model import ABLModel\n", "from abl.utils.utils import reform_idx\n", "\n", "from models.nn import SymbolNet\n", @@ -172,9 +172,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize BasicModel\n", - "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", - "base_model = BasicModel(\n", + "# Initialize BasicNN\n", + "# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", + "base_model = BasicNN(\n", " cls,\n", " criterion,\n", " optimizer,\n", @@ -192,7 +192,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Use WABL model to join two parts" + "### Use ABL model to join two parts" ] }, { @@ -201,7 +201,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = WABLBasicModel(base_model, kb.pseudo_label_list)" + "model = ABLModel(base_model, kb.pseudo_label_list)" ] }, { @@ -262,7 +262,12 @@ "pygments_lexer": "ipython3", "version": "3.8.16" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" 
+ } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 80eb05a..f9469ae 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -18,8 +18,8 @@ "from abl.abducer.kb import KBBase\n", "\n", "from abl.utils.plog import logger\n", - "from abl.models.basic_model import BasicModel\n", - "from abl.models.wabl_models import WABLBasicModel\n", + "from abl.models.basic_nn import BasicNN\n", + "from abl.models.abl_model import ABLModel\n", "\n", "from models.nn import SymbolNet\n", "from datasets.get_hwf import get_hwf\n", @@ -111,9 +111,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize BasicModel\n", - "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", - "base_model = BasicModel(\n", + "# Initialize BasicNN\n", + "# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", + "base_model = BasicNN(\n", " cls,\n", " criterion,\n", " optimizer,\n", @@ -131,7 +131,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Use WABL model to join two parts" + "### Use ABL model to join two parts" ] }, { @@ -140,10 +140,10 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize WABL model\n", - "# The main function of the WABL model is to serialize data and \n", + "# Initialize ABL model\n", + "# The main function of the ABL model is to serialize data and \n", "# provide a unified interface for different machine learning models\n", - "model = WABLBasicModel(base_model, kb.pseudo_label_list)" + "model = ABLModel(base_model, kb.pseudo_label_list)" ] }, { @@ -207,7 +207,12 @@ "pygments_lexer": "ipython3", "version": "3.8.16" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/examples/mnist_add/mnist_add_example.ipynb 
b/examples/mnist_add/mnist_add_example.ipynb index 83b52b4..baa2bcf 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -17,8 +17,8 @@ "from abl.abducer.kb import KBBase, prolog_KB\n", "\n", "from abl.utils.plog import logger\n", - "from abl.models.basic_model import BasicModel\n", - "from abl.models.wabl_models import WABLBasicModel\n", + "from abl.models.basic_nn import BasicNN\n", + "from abl.models.abl_model import ABLModel\n", "\n", "from models.nn import LeNet5\n", "from datasets.get_mnist_add import get_mnist_add\n", @@ -90,9 +90,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize BasicModel\n", - "# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n", - "base_model = BasicModel(\n", + "# Initialize BasicNN\n", + "# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", + "base_model = BasicNN(\n", " cls,\n", " criterion,\n", " optimizer,\n", @@ -110,7 +110,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Use WABL model to join two parts" + "### Use ABL model to join two parts" ] }, { @@ -119,10 +119,10 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize WABL model\n", - "# The main function of the WABL model is to serialize data and \n", + "# Initialize ABL model\n", + "# The main function of the ABL model is to serialize data and \n", "# provide a unified interface for different machine learning models\n", - "model = WABLBasicModel(base_model, kb.pseudo_label_list)" + "model = ABLModel(base_model, kb.pseudo_label_list)" ] }, { @@ -192,7 +192,12 @@ "pygments_lexer": "ipython3", "version": "3.8.16" }, - "orig_nbformat": 4 + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" + } + } }, "nbformat": 4, "nbformat_minor": 2