
[MNT] rename modules in abl.learning

pull/3/head
Gao Enhao 3 years ago
commit 2e73cfd9b0
5 changed files with 740 additions and 28 deletions
  1. abl/learning/abl_model.py (+137, -0)
  2. abl/learning/basic_nn.py (+560, -0)
  3. examples/hed/hed_example.ipynb (+13, -8)
  4. examples/hwf/hwf_example.ipynb (+15, -10)
  5. examples/mnist_add/mnist_add_example.ipynb (+15, -10)

abl/learning/abl_model.py (+137, -0)

@@ -0,0 +1,137 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2020 Freecss All rights reserved.
#
# File Name :abl_model.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2020/04/02
# Description :
#
# ================================================================#
from itertools import chain
from typing import List, Any


def get_part_data(X, i):
    return list(map(lambda x: x[i], X))


def merge_data(X):
    ret_mark = list(map(lambda x: len(x), X))
    ret_X = list(chain(*X))
    return ret_X, ret_mark


def reshape_data(Y, marks):
    begin_mark = 0
    ret_Y = []
    for mark in marks:
        end_mark = begin_mark + mark
        ret_Y.append(Y[begin_mark:end_mark])
        begin_mark = end_mark
    return ret_Y
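
# A quick round-trip of the two helpers above, on hypothetical toy data:
#   merge_data([[1, 2], [3]])       -> ([1, 2, 3], [2, 1])
#   reshape_data([1, 2, 3], [2, 1]) -> [[1, 2], [3]]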


class ABLModel:
    """
    Serialize data and provide a unified interface for different machine learning models.

    Parameters
    ----------
    base_model : Machine Learning Model
        The base model to use for training and prediction.
    pseudo_label_list : List[Any]
        A list of pseudo labels to use for training.

    Attributes
    ----------
    cls_list : List[Any]
        A list of classifiers.
    pseudo_label_list : List[Any]
        A list of pseudo labels to use for training.
    mapping : dict
        A dictionary mapping pseudo labels to integers.
    remapping : dict
        A dictionary mapping integers to pseudo labels.

    Methods
    -------
    predict(X: List[List[Any]]) -> dict
        Predict the class labels and probabilities for the given data.
    valid(X: List[List[Any]], Y: List[Any]) -> float
        Calculate the accuracy score for the given data.
    train(X: List[List[Any]], Y: List[Any])
        Train the model on the given data.
    """

    def __init__(self, base_model, pseudo_label_list: List[Any]):
        self.cls_list = []
        self.cls_list.append(base_model)

        self.pseudo_label_list = pseudo_label_list
        self.mapping = dict(zip(pseudo_label_list, range(len(pseudo_label_list))))
        self.remapping = dict(zip(range(len(pseudo_label_list)), pseudo_label_list))

    def predict(self, X: List[List[Any]]) -> dict:
        """
        Predict the class labels and probabilities for the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to predict on.

        Returns
        -------
        dict
            A dictionary containing the predicted class labels and probabilities.
        """
        data_X, marks = merge_data(X)
        prob = self.cls_list[0].predict_proba(X=data_X)
        _cls = prob.argmax(axis=1)
        cls = list(map(lambda x: self.remapping[x], _cls))

        prob = reshape_data(prob, marks)
        cls = reshape_data(cls, marks)

        return {"cls": cls, "prob": prob}

    def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
        """
        Calculate the accuracy for the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to calculate the accuracy on.
        Y : List[Any]
            The true class labels for the given data.

        Returns
        -------
        float
            The accuracy score for the given data.
        """
        data_X, _ = merge_data(X)
        _data_Y, _ = merge_data(Y)
        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
        score = self.cls_list[0].score(X=data_X, y=data_Y)
        return score

    def train(self, X: List[List[Any]], Y: List[Any]):
        """
        Train the model on the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to train on.
        Y : List[Any]
            The true class labels for the given data.
        """
        data_X, _ = merge_data(X)
        _data_Y, _ = merge_data(Y)
        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
        self.cls_list[0].fit(X=data_X, y=data_Y)
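
For orientation, a minimal usage sketch of the ABLModel interface added above. It assumes the module is importable as abl.learning.abl_model (matching this commit's layout); the scikit-learn classifier and the toy data are illustrative stand-ins, not part of the commit (the notebooks below wrap a neural network with BasicNN instead).

# Minimal sketch (assumptions: abl.learning.abl_model is on the path;
# RandomForestClassifier and the toy data are illustrative only).
from sklearn.ensemble import RandomForestClassifier
from abl.learning.abl_model import ABLModel

base = RandomForestClassifier()  # any estimator with fit / predict_proba / score
model = ABLModel(base, pseudo_label_list=["a", "b"])

# X is a list of examples, each a list of samples; Y gives per-sample pseudo labels.
X = [[[0.0, 1.0], [1.0, 0.0]], [[0.5, 0.5]]]
Y = [["a", "b"], ["a"]]

model.train(X, Y)        # flattens X/Y, maps pseudo labels to ints, fits the base model
out = model.predict(X)   # {"cls": ..., "prob": ...}, reshaped back per example
acc = model.valid(X, Y)  # accuracy of the base model on the flattened data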

abl/learning/basic_nn.py (+560, -0)

@@ -0,0 +1,560 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2020 Freecss All rights reserved.
#
# File Name :basic_nn.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2020/11/21
# Description :
#
# ================================================================#

import os
import sys

sys.path.append("..")

import numpy
import torch
from torch.utils.data import Dataset, DataLoader

from typing import List, Any, Tuple, Optional, Callable, TypeVar

# `typing.T` is not part of typing's public API; define the TypeVar explicitly.
T = TypeVar("T")


class BasicDataset(Dataset):
    def __init__(self, X: List[Any], Y: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Y : List[Any]
            A list of objects representing the output data.
        """
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any]
            A tuple containing the input and output data at the specified index.
        """
        if index >= len(self):
            raise ValueError("index range error")

        img = self.X[index]
        label = self.Y[index]

        return (img, label)


class XYDataset(Dataset):
    def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None):
        """
        Initialize the dataset used for classification tasks.

        Parameters
        ----------
        X : List[Any]
            The input data.
        Y : List[int]
            The target data.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a transformed version. Defaults to None.
        """
        self.X = X
        self.Y = torch.LongTensor(Y)

        self.n_sample = len(X)
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Tuple[Any, torch.Tensor]
            A tuple containing the object and its label.
        """
        if index >= len(self):
            raise ValueError("index range error")

        img = self.X[index]
        if self.transform is not None:
            img = self.transform(img)

        label = self.Y[index]

        return (img, label)


class FakeRecorder:
    """A no-op recorder used when no recorder is supplied."""

    def __init__(self):
        pass

    def print(self, *x):
        pass


class BasicNN:
    """
    Wrap NN models into the form of an sklearn estimator.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    device : torch.device, optional
        The device on which the model will be trained or used for prediction, by default torch.device("cpu").
    batch_size : int, optional
        The batch size used for training, by default 1.
    num_epochs : int, optional
        The number of epochs used for training, by default 1.
    stop_loss : Optional[float], optional
        The loss value at which to stop training, by default 0.01.
    num_workers : int, optional
        The number of workers used for loading data, by default 0.
    save_interval : Optional[int], optional
        The interval at which to save the model during training, by default None.
    save_dir : Optional[str], optional
        The directory in which to save the model during training, by default None.
    transform : Callable[..., Any], optional
        A function/transform that takes in an object and returns a transformed version, by default None.
    collate_fn : Callable[[List[T]], Any], optional
        The function used to collate data, by default None.
    recorder : Any, optional
        The recorder used to record training progress, by default None.

    Attributes
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    batch_size : int
        The batch size used for training.
    num_epochs : int
        The number of epochs used for training.
    stop_loss : Optional[float]
        The loss value at which to stop training.
    num_workers : int
        The number of workers used for loading data.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    transform : Callable[..., Any]
        The transformation function used for data augmentation.
    device : torch.device
        The device on which the model will be trained or used for prediction.
    recorder : Any
        The recorder used to record training progress.
    save_interval : Optional[int]
        The interval at which to save the model during training.
    save_dir : Optional[str]
        The directory in which to save the model during training.
    collate_fn : Callable[[List[T]], Any]
        The function used to collate data.

    Methods
    -------
    fit(data_loader=None, X=None, y=None)
        Train the model.
    train_epoch(data_loader)
        Train the model for one epoch.
    predict(data_loader=None, X=None, print_prefix="")
        Predict the class of the input data.
    predict_proba(data_loader=None, X=None, print_prefix="")
        Predict the probability of each class for the input data.
    score(data_loader=None, X=None, y=None, print_prefix="")
        Score the model.
    _data_loader(X, y=None)
        Generate the data_loader.
    save(epoch_id, save_dir="")
        Save the model.
    load(epoch_id, load_dir="")
        Load the model.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device = torch.device("cpu"),
        batch_size: int = 1,
        num_epochs: int = 1,
        stop_loss: Optional[float] = 0.01,
        num_workers: int = 0,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
        transform: Optional[Callable[..., Any]] = None,
        collate_fn: Optional[Callable[[List[T]], Any]] = None,
        recorder=None,
    ):
        self.model = model.to(device)

        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.stop_loss = stop_loss
        self.num_workers = num_workers

        self.criterion = criterion
        self.optimizer = optimizer
        self.transform = transform
        self.device = device

        if recorder is None:
            recorder = FakeRecorder()
        self.recorder = recorder

        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder
        recorder.print("model fitting")

        min_loss = float("inf")
        for epoch in range(n_epoch):
            loss_value = self.train_epoch(data_loader)
            recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}")
            if loss_value < min_loss:
                min_loss = loss_value
            if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
                if self.save_dir is None:
                    raise ValueError(
                        "save_dir should not be None if save_interval is not None"
                    )
                self.save(epoch + 1, self.save_dir)
            if stop_loss is not None and loss_value < stop_loss:
                break
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def fit(
        self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
    ) -> float:
        """
        Train the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for training, by default None
        X : List[Any], optional
            The input data, by default None
        y : List[int], optional
            The target data, by default None

        Returns
        -------
        float
            The loss value of the trained model.
        """
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        return self._fit(data_loader, self.num_epochs, self.stop_loss)

    def train_epoch(self, data_loader: DataLoader):
        """
        Train the model for one epoch.

        Parameters
        ----------
        data_loader : DataLoader
            The data loader used for training.

        Returns
        -------
        float
            The mean loss over the epoch.
        """
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer
        device = self.device

        model.train()

        total_loss, total_num = 0.0, 0
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss = criterion(out, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * data.size(0)
            total_num += data.size(0)

        return total_loss / total_num

    def _predict(self, data_loader):
        model = self.model
        device = self.device

        model.eval()

        with torch.no_grad():
            results = []
            for data, _ in data_loader:
                data = data.to(device)
                out = model(data)
                results.append(out)

        return torch.cat(results, dim=0)

    def predict(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the class of the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None
        X : List[Any], optional
            The input data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        numpy.ndarray
            The predicted class of the input data.
        """
        recorder = self.recorder
        recorder.print("Start Predict Class ", print_prefix)

        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).argmax(dim=1).cpu().numpy()

    def predict_proba(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the probability of each class for the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None
        X : List[Any], optional
            The input data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        numpy.ndarray
            The predicted probability of each class for the input data.
        """
        recorder = self.recorder
        recorder.print("Start Predict Probability ", print_prefix)

        if data_loader is None:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).softmax(dim=1).cpu().numpy()

    def _score(self, data_loader):
        model = self.model
        criterion = self.criterion
        device = self.device

        model.eval()

        total_correct_num, total_num, total_loss = 0, 0, 0.0

        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)

                out = model(data)

                if len(out.shape) > 1:
                    correct_num = (target == out.argmax(dim=1)).sum().item()
                else:
                    correct_num = (target == (out > 0.5)).sum().item()
                loss = criterion(out, target)
                total_loss += loss.item() * data.size(0)

                total_correct_num += correct_num
                total_num += data.size(0)

        mean_loss = total_loss / total_num
        accuracy = total_correct_num / total_num

        return mean_loss, accuracy

    def score(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        y: List[int] = None,
        print_prefix: str = "",
    ) -> float:
        """
        Score the model: compute its accuracy on the given data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for scoring, by default None
        X : List[Any], optional
            The input data, by default None
        y : List[int], optional
            The target data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        float
            The accuracy of the model.
        """
        recorder = self.recorder
        recorder.print("Start validation ", print_prefix)

        if data_loader is None:
            data_loader = self._data_loader(X, y)
        mean_loss, accuracy = self._score(data_loader)
        recorder.print(
            "[%s] mean loss: %f, accuracy: %f" % (print_prefix, mean_loss, accuracy)
        )
        return accuracy

    def _data_loader(
        self,
        X: List[Any],
        y: List[int] = None,
    ) -> DataLoader:
        """
        Generate a data_loader for user-provided data.

        Parameters
        ----------
        X : List[Any]
            The input data.
        y : List[int], optional
            The target data, by default None

        Returns
        -------
        DataLoader
            The data loader.
        """
        collate_fn = self.collate_fn
        transform = self.transform

        if y is None:
            y = [0] * len(X)
        dataset = XYDataset(X, y, transform=transform)
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=int(self.num_workers),
            collate_fn=collate_fn,
        )
        return data_loader

    def save(self, epoch_id: int, save_dir: str = ""):
        """
        Save the model and the optimizer.

        Parameters
        ----------
        epoch_id : int
            The epoch id.
        save_dir : str, optional
            The directory to save the model, by default ""
        """
        recorder = self.recorder
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        recorder.print("Saving model and optimizer")
        save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
        torch.save(self.model.state_dict(), save_path)

        save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)

    def load(self, epoch_id: int, load_dir: str = ""):
        """
        Load the model and the optimizer.

        Parameters
        ----------
        epoch_id : int
            The epoch id.
        load_dir : str, optional
            The directory to load the model from, by default ""
        """
        recorder = self.recorder
        recorder.print("Loading model and optimizer")
        load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")
        self.model.load_state_dict(torch.load(load_path))

        load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth")
        self.optimizer.load_state_dict(torch.load(load_path))


if __name__ == "__main__":
    pass
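
To see the wrapper driven end to end, a minimal sketch follows. It assumes basic_nn.py is importable as abl.learning.basic_nn (matching this commit's layout); the toy linear model, random data, and hyperparameters are illustrative only, not taken from the repository.

# Minimal sketch (assumptions: abl.learning.basic_nn is importable;
# the model, data, and hyperparameters below are illustrative only).
import torch
from torch import nn
from abl.learning.basic_nn import BasicNN

net = nn.Linear(4, 3)  # toy 4-feature, 3-class classifier
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

base_model = BasicNN(net, criterion, optimizer, batch_size=8, num_epochs=3)

X = [torch.randn(4) for _ in range(32)]  # list of per-sample feature tensors
y = [i % 3 for i in range(32)]           # integer class labels

loss = base_model.fit(X=X, y=y)       # sklearn-style fit; returns the last epoch's loss
pred = base_model.predict(X=X)        # numpy array of predicted class indices
prob = base_model.predict_proba(X=X)  # numpy array of per-class probabilities
acc = base_model.score(X=X, y=y)      # accuracy, also logged via the recorder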

examples/hed/hed_example.ipynb (+13, -8)

@@ -18,8 +18,8 @@
"from abl.abducer.kb import prolog_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"from abl.models.basic_nn import BasicNN\n",
"from abl.models.abl_model import ABLModel\n",
"from abl.utils.utils import reform_idx\n",
"\n",
"from models.nn import SymbolNet\n",
@@ -172,9 +172,9 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
"# Initialize BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
@@ -192,7 +192,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
"### Use ABL model to join two parts"
]
},
{
@@ -201,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
"model = ABLModel(base_model, kb.pseudo_label_list)"
]
},
{
@@ -262,7 +262,12 @@
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
}
}
},
"nbformat": 4,
"nbformat_minor": 2


examples/hwf/hwf_example.ipynb (+15, -10)

@@ -18,8 +18,8 @@
"from abl.abducer.kb import KBBase\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"from abl.models.basic_nn import BasicNN\n",
"from abl.models.abl_model import ABLModel\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.get_hwf import get_hwf\n",
@@ -111,9 +111,9 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
"# Initialize BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
@@ -131,7 +131,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
"### Use ABL model to join two parts"
]
},
{
@@ -140,10 +140,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize WABL model\n",
"# The main function of the WABL model is to serialize data and \n",
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# provide a unified interface for different machine learning models\n",
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
"model = ABLModel(base_model, kb.pseudo_label_list)"
]
},
{
@@ -207,7 +207,12 @@
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
}
}
},
"nbformat": 4,
"nbformat_minor": 2


examples/mnist_add/mnist_add_example.ipynb (+15, -10)

@@ -17,8 +17,8 @@
"from abl.abducer.kb import KBBase, prolog_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"from abl.models.basic_nn import BasicNN\n",
"from abl.models.abl_model import ABLModel\n",
"\n",
"from models.nn import LeNet5\n",
"from datasets.get_mnist_add import get_mnist_add\n",
@@ -90,9 +90,9 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
"# Initialize BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
@@ -110,7 +110,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
"### Use ABL model to join two parts"
]
},
{
@@ -119,10 +119,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize WABL model\n",
"# The main function of the WABL model is to serialize data and \n",
"# Initialize ABL model\n",
"# The main function of the ABL model is to serialize data and \n",
"# provide a unified interface for different machine learning models\n",
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
"model = ABLModel(base_model, kb.pseudo_label_list)"
]
},
{
@@ -192,7 +192,12 @@
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
}
}
},
"nbformat": 4,
"nbformat_minor": 2

