Browse Source

Merge pull request #4 from AbductiveLearning/parrial_ab_data

Add abstract data interface to bridge, dataset, evaluation and learning.
pull/1/head
Huang Yuxuan GitHub 2 years ago
parent
commit
541777f968
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1787 additions and 527 deletions
  1. +29
    -17
      abl/bridge/base_bridge.py
  2. +87
    -74
      abl/bridge/simple_bridge.py
  3. +2
    -1
      abl/dataset/__init__.py
  4. +2
    -1
      abl/dataset/bridge_dataset.py
  5. +2
    -1
      abl/dataset/classification_dataset.py
  6. +56
    -0
      abl/dataset/prediction_dataset.py
  7. +2
    -1
      abl/dataset/regression_dataset.py
  8. +1
    -1
      abl/evaluation/__init__.py
  9. +2
    -2
      abl/evaluation/base_metric.py
  10. +10
    -11
      abl/evaluation/semantics_metric.py
  11. +4
    -3
      abl/evaluation/symbol_metric.py
  12. +35
    -63
      abl/learning/abl_model.py
  13. +50
    -39
      abl/learning/basic_nn.py
  14. +2
    -2
      abl/reasoning/__init__.py
  15. +118
    -101
      abl/reasoning/kb.py
  16. +107
    -139
      abl/reasoning/reasoner.py
  17. +2
    -0
      abl/structures/__init__.py
  18. +629
    -0
      abl/structures/base_data_element.py
  19. +305
    -0
      abl/structures/list_data.py
  20. +2
    -1
      abl/utils/__init__.py
  21. +104
    -0
      abl/utils/cache.py
  22. +57
    -7
      abl/utils/utils.py
  23. +12
    -12
      examples/hed/hed_bridge.py
  24. +142
    -27
      examples/hwf/hwf_example.ipynb
  25. +25
    -24
      examples/mnist_add/mnist_add_example.ipynb

+ 29
- 17
abl/bridge/base_bridge.py View File

@@ -1,52 +1,64 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple, Union

from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..structures import ListData

DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]

class BaseBridge(metaclass=ABCMeta):

def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None:
class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: ReasonerBase) -> None:
if not isinstance(model, ABLModel):
raise TypeError("Expected an ABLModel")
if not isinstance(abducer, ReasonerBase):
raise TypeError("Expected an ReasonerBase")
raise TypeError(
"Expected an instance of ABLModel, but received type: {}".format(
type(model)
)
)
if not isinstance(reasoner, ReasonerBase):
raise TypeError(
"Expected an instance of ReasonerBase, but received type: {}".format(
type(reasoner)
)
)

self.model = model
self.abducer = abducer
self.reasoner = reasoner

@abstractmethod
def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]:
def predict(
self, data_samples: ListData
) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predict labels from input."""
pass

@abstractmethod
def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels."""
pass

@abstractmethod
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]:
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map label space to symbol space."""
pass

@abstractmethod
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map symbol space to label space."""
pass
@abstractmethod
def train(self, train_data):
def train(self, train_data: Union[ListData, DataSet]):
"""Placeholder for train loop of ABductive Learning."""
pass

@abstractmethod
def test(self, test_data):
def valid(self, valid_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model test."""
pass

@abstractmethod
def valid(self, valid_data):
def test(self, test_data: Union[ListData, DataSet]) -> None:
"""Placeholder for model validation."""
pass

+ 87
- 74
abl/bridge/simple_bridge.py View File

@@ -1,104 +1,119 @@
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..evaluation import BaseMetric
from .base_bridge import BaseBridge
from typing import List, Union, Any, Tuple, Dict, Optional
import os.path as osp
from typing import Any, Dict, List, Optional, Tuple, Union

from numpy import ndarray

from torch.utils.data import DataLoader
from ..dataset import BridgeDataset
from ..utils.logger import print_log
from ..evaluation import BaseMetric
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..structures import ListData
from ..utils import print_log
from .base_bridge import BaseBridge, DataSet


class SimpleBridge(BaseBridge):
def __init__(
self,
model: ABLModel,
abducer: ReasonerBase,
reasoner: ReasonerBase,
metric_list: List[BaseMetric],
) -> None:
super().__init__(model, abducer)
super().__init__(model, reasoner)
self.metric_list = metric_list

def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
pred_res = self.model.predict(X)
pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
return pred_idx, pred_prob
# TODO: add reasoner.mapping to the property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob

def abduce_pseudo_label(
self,
pred_prob: ndarray,
pred_pseudo_label: List[List[Any]],
Y: List[Any],
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)
self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision)
return data_samples.abduced_pseudo_label

def idx_to_pseudo_label(
self, idx: List[List[Any]], mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.mapping
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]
mapping = self.reasoner.mapping
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples.pred_pseudo_label

def pseudo_label_to_idx(
self, pseudo_label: List[List[Any]], mapping: Dict = None
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.remapping
return [
[mapping[_pseudo_label] for _pseudo_label in sub_list]
for sub_list in pseudo_label
mapping = self.reasoner.remapping
abduced_idx = [
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
]
data_samples.abduced_idx = abduced_idx
return data_samples.abduced_idx

def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData:
data_samples = ListData()

data_samples.X = X
data_samples.gt_pseudo_label = gt_pseudo_label
data_samples.Y = Y

return data_samples

def train(
self,
train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
epochs: int = 50,
batch_size: Union[int, float] = -1,
train_data: Union[ListData, DataSet],
loops: int = 50,
segment_size: Union[int, float] = -1,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
dataset = BridgeDataset(*train_data)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)],
)

for epoch in range(epochs):
for seg_idx, (X, Z, Y) in enumerate(data_loader):
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
abduced_pseudo_label = self.abduce_pseudo_label(
pred_prob, pred_pseudo_label, Y
)
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
loss = self.model.train(X, abduced_label)
if isinstance(train_data, ListData):
data_samples = train_data
else:
data_samples = self.data_preprocess(*train_data)

for loop in range(loops):
for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
sub_data_samples = data_samples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)
self.abduce_pseudo_label(sub_data_samples)
self.pseudo_label_to_idx(sub_data_samples)
loss = self.model.train(sub_data_samples)

print_log(
f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{len(data_loader)}] model loss is {loss:.5f}",
f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}",
logger="current",
)

if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1:
print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current")
if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
self.valid(train_data)

def _valid(self, data_loader):
for X, Z, Y in data_loader:
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
data_samples = dict(
pred_idx=pred_idx,
pred_prob=pred_prob,
pred_pseudo_label=pred_pseudo_label,
gt_pseudo_label=Z,
Y=Y,
logic_forward=self.abducer.kb.logic_forward,
)
for metric in self.metric_list:
metric.process(data_samples)
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
self.model.save(
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
)

def _valid(self, data_samples: ListData) -> None:
self.predict(data_samples)
self.idx_to_pseudo_label(data_samples)

for metric in self.metric_list:
metric.process(data_samples)

res = dict()
for metric in self.metric_list:
@@ -108,14 +123,12 @@ class SimpleBridge(BaseBridge):
msg += k + f": {v:.3f} "
print_log(msg, logger="current")

def valid(self, valid_data, batch_size=1000):
dataset = BridgeDataset(*valid_data)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=lambda data_list: [list(data) for data in zip(*data_list)],
)
self._valid(data_loader)

def test(self, test_data, batch_size=1000):
self.valid(test_data, batch_size)
def valid(self, valid_data: Union[ListData, DataSet]) -> None:
if not isinstance(valid_data, ListData):
data_samples = self.data_preprocess(*valid_data)
else:
data_samples = valid_data
self._valid(data_samples)

def test(self, test_data: Union[ListData, DataSet]) -> None:
self.valid(test_data)

+ 2
- 1
abl/dataset/__init__.py View File

@@ -1,3 +1,4 @@
from .bridge_dataset import BridgeDataset
from .classification_dataset import ClassificationDataset
from .regression_dataset import RegressionDataset
from .prediction_dataset import PredictionDataset
from .regression_dataset import RegressionDataset

+ 2
- 1
abl/dataset/bridge_dataset.py View File

@@ -1,5 +1,6 @@
from typing import Any, List, Tuple

from torch.utils.data import Dataset
from typing import List, Any, Tuple


class BridgeDataset(Dataset):


+ 2
- 1
abl/dataset/classification_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple, Callable


class ClassificationDataset(Dataset):


+ 56
- 0
abl/dataset/prediction_dataset.py View File

@@ -0,0 +1,56 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset


class PredictionDataset(Dataset):
def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if not isinstance(X, list):
raise ValueError("X should be of type list.")

self.X = X
self.transform = transform

def __len__(self) -> int:
"""
Return the length of the dataset.

Returns
-------
int
The length of the dataset.
"""
return len(self.X)

def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
"""
Get the item at the given index.

Parameters
----------
index : int
The index of the item to get.

Returns
-------
Tuple[Any, torch.Tensor]
A tuple containing the object and its label.
"""
if index >= len(self):
raise ValueError("index range error")

x = self.X[index]
if self.transform is not None:
x = self.transform(x)
return x

+ 2
- 1
abl/dataset/regression_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple


class RegressionDataset(Dataset):


+ 1
- 1
abl/evaluation/__init__.py View File

@@ -1,3 +1,3 @@
from .base_metric import BaseMetric
from .symbol_metric import SymbolMetric
from .semantics_metric import SemanticsMetric
from .symbol_metric import SymbolMetric

+ 2
- 2
abl/evaluation/base_metric.py View File

@@ -1,8 +1,8 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence
from ..utils import print_log

import logging
from ..utils import print_log


class BaseMetric(metaclass=ABCMeta):


+ 10
- 11
abl/evaluation/semantics_metric.py View File

@@ -1,25 +1,24 @@
from typing import Optional, Sequence

from ..reasoning import KBBase
from .base_metric import BaseMetric

class ABLMetric():
pass

class SemanticsMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None:
super().__init__(prefix)
self.kb = kb

def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]
gt_Y = data_samples["Y"]
logic_forward = data_samples["logic_forward"]

for pred_z, y in zip(pred_pseudo_label, gt_Y):
if logic_forward(pred_z) == y:
pred_pseudo_label_list = data_samples.pred_pseudo_label
y_list = data_samples.Y
for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list):
if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label), y):
self.results.append(1)
else:
self.results.append(0)
def compute_metrics(self, results: list) -> dict:
metrics = dict()
metrics["semantics_accuracy"] = sum(results) / len(results)
return metrics
return metrics

+ 4
- 3
abl/evaluation/symbol_metric.py View File

@@ -1,4 +1,5 @@
from typing import Optional, Sequence, Callable
from typing import Optional, Sequence

from .base_metric import BaseMetric


@@ -7,9 +8,9 @@ class SymbolMetric(BaseMetric):
super().__init__(prefix)

def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]
pred_pseudo_label = data_samples.pred_pseudo_label

gt_pseudo_label = data_samples["gt_pseudo_label"]
gt_pseudo_label = data_samples.gt_pseudo_label

if not len(pred_pseudo_label) == len(gt_pseudo_label):
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")


+ 35
- 63
abl/learning/abl_model.py View File

@@ -10,8 +10,10 @@
#
# ================================================================#
import pickle
from utils import flatten, reform_idx
from typing import List, Any, Optional
from typing import Any, Dict

from ..structures import ListData
from ..utils import reform_list


class ABLModel:
@@ -30,7 +32,7 @@ class ABLModel:

Methods
-------
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict
Predict the labels and probabilities for the given data.
valid(X: List[List[Any]], Y: List[Any]) -> float
Calculate the accuracy score for the given data.
@@ -42,20 +44,13 @@ class ABLModel:
Load the model from a file.
"""

def __init__(self, base_model) -> None:
self.classifier_list = []
self.classifier_list.append(base_model)
def __init__(self, base_model: Any) -> None:
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
raise NotImplementedError("The base_model should implement fit and predict methods.")

if not (
hasattr(base_model, "fit")
and hasattr(base_model, "predict")
and hasattr(base_model, "score")
):
raise NotImplementedError(
"base_model should have fit, predict and score methods."
)
self.base_model = base_model

def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict:
def predict(self, data_samples: ListData) -> Dict:
"""
Predict the labels and probabilities for the given data.

@@ -63,53 +58,29 @@ class ABLModel:
----------
X : List[List[Any]]
The data to predict on.
mapping : Optional[dict], optional
A mapping dictionary to map labels to their original values, by default None.

Returns
-------
dict
A dictionary containing the predicted labels and probabilities.
"""
model = self.classifier_list[0]
data_X = flatten(X)
model = self.base_model
data_X = data_samples.flatten("X")
if hasattr(model, "predict_proba"):
prob = model.predict_proba(X=data_X)
label = prob.argmax(axis=1)
prob = reform_idx(prob, X)
prob = reform_list(prob, data_samples.X)
else:
prob = None
label = model.predict(X=data_X)
label = reform_list(label, data_samples.X)

if mapping is not None:
label = [mapping[y] for y in label]

label = reform_idx(label, X)
data_samples.pred_idx = label
data_samples.pred_prob = prob

return {"label": label, "prob": prob}

def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
"""
Calculate the accuracy for the given data.

Parameters
----------
X : List[List[Any]]
The data to calculate the accuracy on.
Y : List[Any]
The true labels for the given data.

Returns
-------
float
The accuracy score for the given data.
"""
data_X = flatten(X)
data_Y = flatten(Y)
score = self.classifier_list[0].score(X=data_X, y=data_Y)
return score

def train(self, X: List[List[Any]], Y: List[Any]) -> float:
def train(self, data_samples: ListData) -> float:
"""
Train the model on the given data.

@@ -125,29 +96,30 @@ class ABLModel:
float
The loss value of the trained model.
"""
data_X = flatten(X)
data_Y = flatten(Y)
return self.classifier_list[0].fit(X=data_X, y=data_Y)
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
return self.base_model.fit(X=data_X, y=data_y)

def _model_operation(self, operation: str, *args, **kwargs):
model = self.classifier_list[0]
model = self.base_model
if hasattr(model, operation):
method = getattr(model, operation)
method(*args, **kwargs)
else:
try:
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
if operation == "save":
with open(kwargs["save_path"], 'wb') as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], 'rb') as file:
self.classifier_list[0] = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method"
)
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
)

def save(self, *args, **kwargs) -> None:
"""


+ 50
- 39
abl/learning/basic_nn.py View File

@@ -10,14 +10,16 @@
#
# ================================================================#

import torch
import os
import logging
from typing import Any, Callable, List, Optional, T, Tuple

import numpy
import torch
from torch.utils.data import DataLoader
from ..utils.logger import print_log
from ..dataset import ClassificationDataset

import os
from typing import List, Any, T, Optional, Callable, Tuple
from ..dataset import ClassificationDataset, PredictionDataset
from ..utils.logger import print_log


class BasicNN:
@@ -64,7 +66,8 @@ class BasicNN:
num_workers: int = 0,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
transform: Callable[..., Any] = None,
train_transform: Callable[..., Any] = None,
test_transform: Callable[..., Any] = None,
collate_fn: Callable[[List[T]], Any] = None,
) -> None:
self.model = model.to(device)
@@ -77,10 +80,19 @@ class BasicNN:
self.num_workers = num_workers
self.save_interval = save_interval
self.save_dir = save_dir
self.transform = transform
self.train_transform = train_transform
self.test_transform = test_transform
self.collate_fn = collate_fn

def _fit(self, data_loader) -> float:
if self.train_transform is not None and self.test_transform is None:
print_log(
"Transform used in the training phase will be used in prediction.",
"current",
level=logging.WARNING,
)
self.test_transform = self.train_transform

def _fit(self, data_loader: DataLoader) -> float:
"""
Internal method to fit the model on data for n epochs, with early stopping.

@@ -99,9 +111,7 @@ class BasicNN:
loss_value = self.train_epoch(data_loader)
if self.save_interval is not None and (epoch + 1) % self.save_interval == 0:
if self.save_dir is None:
raise ValueError(
"save_dir should not be None if save_interval is not None."
)
raise ValueError("save_dir should not be None if save_interval is not None.")
self.save(epoch + 1)
if self.stop_loss is not None and loss_value < self.stop_loss:
break
@@ -170,7 +180,7 @@ class BasicNN:

return total_loss / total_num

def _predict(self, data_loader) -> torch.Tensor:
def _predict(self, data_loader: DataLoader) -> torch.Tensor:
"""
Internal method to predict the outputs given a DataLoader.

@@ -191,16 +201,14 @@ class BasicNN:

with torch.no_grad():
results = []
for data, _ in data_loader:
for data in data_loader:
data = data.to(device)
out = model(data)
results.append(out)

return torch.cat(results, axis=0)

def predict(
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray:
def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
"""
Predict the class of the input data.

@@ -218,12 +226,16 @@ class BasicNN:
"""

if data_loader is None:
data_loader = self._data_loader(X)
dataset = PredictionDataset(X, self.test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(
self, data_loader: DataLoader = None, X: List[Any] = None
) -> numpy.ndarray:
def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray:
"""
Predict the probability of each class for the input data.

@@ -241,10 +253,16 @@ class BasicNN:
"""

if data_loader is None:
data_loader = self._data_loader(X)
dataset = PredictionDataset(X, self.test_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

def _score(self, data_loader) -> Tuple[float, float]:
def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
"""
Internal method to compute loss and accuracy for the data provided through a DataLoader.

@@ -313,16 +331,10 @@ class BasicNN:
if data_loader is None:
data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._score(data_loader)
print_log(
f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current"
)
print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
return accuracy

def _data_loader(
self,
X: List[Any],
y: List[int] = None,
) -> DataLoader:
def _data_loader(self, X: List[Any], y: List[int] = None, shuffle: bool = True) -> DataLoader:
"""
Generate a DataLoader for user-provided input and target data.

@@ -346,11 +358,11 @@ class BasicNN:
if not (len(y) == len(X)):
raise ValueError("X and y should have equal length.")

dataset = ClassificationDataset(X, y, transform=self.transform)
dataset = ClassificationDataset(X, y, transform=self.train_transform)
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=True,
shuffle=shuffle,
num_workers=int(self.num_workers),
collate_fn=self.collate_fn,
)
@@ -368,14 +380,13 @@ class BasicNN:
The path to save the model, by default None.
"""
if self.save_dir is None and save_path is None:
raise ValueError(
"'save_dir' and 'save_path' should not be None simultaneously."
)
raise ValueError("'save_dir' and 'save_path' should not be None simultaneously.")

if save_path is None:
save_path = os.path.join(
self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth"
)
if save_path is not None:
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
else:
save_path = os.path.join(self.save_dir, f"model_checkpoint_epoch_{epoch_id}.pth")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)



+ 2
- 2
abl/reasoning/__init__.py View File

@@ -1,2 +1,2 @@
from .reasoner import ReasonerBase
from .kb import KBBase, prolog_KB
from .kb import KBBase, GroundKB, PrologKB
from .reasoner import ReasonerBase

+ 118
- 101
abl/reasoning/kb.py View File

@@ -9,7 +9,8 @@ from functools import lru_cache
import numpy as np
import pyswip

from abl.utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable
from ..utils.utils import flatten, reform_list, hamming_dist, to_hashable, restore_from_hashable
from ..utils.cache import abl_cache


class KBBase(ABC):
@@ -21,35 +22,46 @@ class KBBase(ABC):
pseudo_label_list : list
List of possible pseudo labels.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate's logical
result. This is only applicable when the logical result is of a numerical type.
This is particularly relevant for regression problems where exact matches might not be
feasible. Defaults to 1e-10.
The upper tolerance limit when comparing the similarity between a candidate's logical
result. This is only applicable when the logical result is of a numerical type.
This is particularly relevant for regression problems where exact matches might not be
feasible. Defaults to 1e-10.
use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent
Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
Notes
-----
Users should inherit from this base class to build their own knowledge base. For the
user-build KB (an inherited subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
Users should inherit from this base class to build their own knowledge base. For the
user-build KB (an inherited subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
"""
def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True):

def __init__(
self,
pseudo_label_list,
max_err=1e-10,
use_cache=True,
key_func=to_hashable,
max_cache_size=4096,
):
if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list
self.max_err = max_err
self.use_cache = use_cache

self.use_cache = use_cache
self.key_func = key_func
self.max_cache_size = max_cache_size

@abstractmethod
def logic_forward(self, pseudo_label):
"""
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to
How to perform (deductive) logical reasoning, i.e. matching each pseudo label to
their logical result. Users are required to provide this.
Parameters
----------
pred_pseudo_label : List[Any]
@@ -70,23 +82,22 @@ class KBBase(ABC):
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo labels that are consistent with the
A list of candidates, i.e. revised pseudo labels that are consistent with the
knowledge base.
"""
if self.use_cache:
return self._abduce_by_search_cache(to_hashable(pred_pseudo_label),
to_hashable(y),
max_revision_num, require_more_revision)
else:
return self._abduce_by_search(pred_pseudo_label, y,
max_revision_num, require_more_revision)
# if self.use_cache:
# return self._abduce_by_search_cache(to_hashable(pred_pseudo_label),
# to_hashable(y),
# max_revision_num, require_more_revision)
# else:
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)

def _check_equal(self, logic_result, y):
"""
Check whether the logical result of a candidate is equal to the ground truth
@@ -94,12 +105,12 @@ class KBBase(ABC):
"""
if logic_result == None:
return False
if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)):
return abs(logic_result - y) <= self.max_err
else:
return logic_result == y
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions.
@@ -125,7 +136,7 @@ class KBBase(ABC):

def _revision(self, revision_num, pred_pseudo_label, y):
"""
For a specified number of pseudo label to revise, iterate through all possible
For a specified number of pseudo label to revise, iterate through all possible
indices to find any candidates that are consistent with the knowledge base.
"""
new_candidates = []
@@ -136,12 +147,13 @@ class KBBase(ABC):
new_candidates.extend(candidates)
return new_candidates

def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
@abl_cache()
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
continuously increase the number of pseudo labels to revise, until candidates
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
continuously increase the number of pseudo labels to revise, until candidates
that are consistent with the knowledge base are found.
Parameters
----------
pred_pseudo_label : List[Any]
@@ -151,16 +163,16 @@ class KBBase(ABC):
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
If larger than 0, then after having found any candidates consistent with the
knowledge base, continue to increase the number pseudo labels to revise to
If larger than 0, then after having found any candidates consistent with the
knowledge base, continue to increase the number pseudo labels to revise to
get more possible consistent candidates.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
A list of candidates, i.e. revised pseudo label that are consistent with the
knowledge base.
"""
"""
candidates = []
for revision_num in range(len(pred_pseudo_label) + 1):
if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y):
@@ -173,20 +185,22 @@ class KBBase(ABC):
if revision_num >= max_revision_num:
return []

for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
for revision_num in range(
min_revision_num + 1, min_revision_num + require_more_revision + 1
):
if revision_num > max_revision_num:
return candidates
candidates.extend(self._revision(revision_num, pred_pseudo_label, y))
return candidates
@lru_cache(maxsize=4096)
def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
`_abduce_by_search` with cache.
"""
pred_pseudo_label = restore_from_hashable(pred_pseudo_label)
y = restore_from_hashable(y)
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)
# @abl_cache(max_size=4096)
# def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
# """
# `_abduce_by_search` with cache.
# """
# pred_pseudo_label = restore_from_hashable(pred_pseudo_label)
# y = restore_from_hashable(y)
# return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)

def __repr__(self):
return (
@@ -195,13 +209,13 @@ class KBBase(ABC):
f"max_err={self.max_err!r}, "
f"use_cache={self.use_cache!r}."
)
class GroundKB(KBBase):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, storing all potential candidates along with their respective
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, storing all potential candidates along with their respective
logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.

Parameters
----------
@@ -211,15 +225,16 @@ class GroundKB(KBBase):
List of possible lengths of pseudo label.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
After that, other operations (e.g. auto-construction of GKB, and how to perform
After that, other operations (e.g. auto-construction of GKB, and how to perform
abductive reasoning) will be automatically set up.
"""

def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10):
super().__init__(pseudo_label_list, max_err)
if not isinstance(GKB_len_list, list):
@@ -229,7 +244,6 @@ class GroundKB(KBBase):
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)

def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
@@ -259,21 +273,21 @@ class GroundKB(KBBase):
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
if Y and isinstance(Y[0], (int, float)):
if Y and isinstance(Y[0], (int, float)):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
"""
Perform abductive reasoning by directly retrieving consistent candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
Perform abductive reasoning by directly retrieving consistent candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
avoided.
This is an overridden function. For more information about the parameters and
This is an overridden function. For more information about the parameters and
returns, refer to the function of the same name in class `KBBase`.
"""
if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y)
if len(all_candidates) == 0:
return []
@@ -284,29 +298,30 @@ class GroundKB(KBBase):
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
def _find_candidate_GKB(self, pred_pseudo_label, y):
"""
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results,
return all candidates whose logical results fall within the
Retrieve consistent candidates from the prebuilt GKB. For numerical logical results,
return all candidates whose logical results fall within the
[y - max_err, y + max_err] range.
"""
if isinstance(y, (int, float)):
potential_candidates = self.GKB[len(pred_pseudo_label)]
key_list = list(potential_candidates.keys())
low_key = bisect.bisect_left(key_list, y - self.max_err)
high_key = bisect.bisect_right(key_list, y + self.max_err)

all_candidates = [candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]]
all_candidates = [
candidate
for key in key_list[low_key:high_key]
for candidate in potential_candidates[key]
]
return all_candidates
else:
return self.GKB[len(pred_pseudo_label)][y]

def __repr__(self):
return (
f"{self.__class__.__name__} is a KB with "
@@ -321,78 +336,80 @@ class GroundKB(KBBase):
class PrologKB(KBBase):
"""
Knowledge base provided by a Prolog (.pl) file.
Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
pl_file :
Prolog file containing the KB.
pl_file :
Prolog file containing the KB.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can instantiate this class to build their own knowledge base. During the
Users can instantiate this class to build their own knowledge base. During the
instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`.
To use the default logic forward and abductive reasoning methods in this class, in the
Prolog (.pl) file, there needs to be a rule which is strictly formatted as
To use the default logic forward and abductive reasoning methods in this class, in the
Prolog (.pl) file, there needs to be a rule which is strictly formatted as
`logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`.
For specifics, refer to the `logic_forward` and `get_query_string` functions in this
For specifics, refer to the `logic_forward` and `get_query_string` functions in this
class. Users are also welcome to override related functions for more flexible support.
"""

def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list)
self.pl_file = pl_file
self.prolog = pyswip.Prolog()
if not os.path.exists(self.pl_file):
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
self.prolog.consult(self.pl_file)

def logic_forward(self, pseudo_labels):
"""
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
returned `Res` as the logical results. To use this default function, there must be
a Prolog `log_forward` method in the pl file to perform logical. reasoning.
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
returned `Res` as the logical results. To use this default function, there must be
a Prolog `log_forward` method in the pl file to perform logical. reasoning.
Otherwise, users would override this function.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
if result == 'true':
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"]
if result == "true":
return True
elif result == 'false':
elif result == "false":
return False
return result
def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx):
import re

revision_pred_pseudo_label = pred_pseudo_label.copy()
revision_pred_pseudo_label = flatten(revision_pred_pseudo_label)
for idx in revision_idx:
revision_pred_pseudo_label[idx] = 'P' + str(idx)
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label)
revision_pred_pseudo_label[idx] = "P" + str(idx)
revision_pred_pseudo_label = reform_list(revision_pred_pseudo_label, pred_pseudo_label)
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label))
def get_query_string(self, pred_pseudo_label, y, revision_idx):
"""
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set
the returned `Revise_labels` together with the kept labels as the candidates. This is
Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set
the returned `Revise_labels` together with the kept labels as the candidates. This is
a default fuction for demo, users would override this function to adapt to their own
Prolog file.
Prolog file.
"""
query_string = "logic_forward("
query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions by querying Prolog.
This is an overridden function. For more information about the parameters, refer to
This is an overridden function. For more information about the parameters, refer to
the function of the same name in class `KBBase`.
"""
candidates = []
@@ -404,7 +421,7 @@ class PrologKB(KBBase):
candidate = pred_pseudo_label.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_idx(candidate, save_pred_pseudo_label)
candidate = reform_list(candidate, save_pred_pseudo_label)
candidates.append(candidate)
return candidates

@@ -414,4 +431,4 @@ class PrologKB(KBBase):
f"pseudo_label_list={self.pseudo_label_list!r}, "
f"defined by "
f"Prolog file {self.pl_file!r}."
)
)

+ 107
- 139
abl/reasoning/reasoner.py View File

@@ -1,9 +1,9 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from abl.utils.utils import (
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
reform_list,
hamming_dist,
)

@@ -191,7 +191,7 @@ class ReasonerBase:
return max_revision
def abduce(
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0
self, data_sample, max_revision=-1, require_more_revision=0
):
"""
Perform abductive reasoning on the given prediction data.
@@ -219,9 +219,13 @@ class ReasonerBase:
A revised pseudo label through abductive reasoning, which is consistent with the
knowledge base.
"""
symbol_num = len(flatten(pred_pseudo_label))
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)

pred_pseudo_label = data_sample.pred_pseudo_label
pred_prob = data_sample.pred_prob
y = data_sample.Y
if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
@@ -237,20 +241,18 @@ class ReasonerBase:
return candidate

def batch_abduce(
self, pred_probs, pred_pseudo_labels, Ys, max_revision=-1, require_more_revision=0
self, data_samples, max_revision=-1, require_more_revision=0
):
"""
Perform abductive reasoning on the given prediction data in batches.
For detailed information, refer to `abduce`.
"""
return [
self.abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)
for pred_prob, pred_pseudo_label, Y in zip(
pred_probs, pred_pseudo_labels, Ys
)
abduced_pseudo_label = [
self.abduce(data_sample, max_revision, require_more_revision)
for data_sample in data_samples
]
data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
@@ -273,12 +275,11 @@ class ReasonerBase:

if __name__ == "__main__":
from kb import KBBase, GroundKB, PrologKB

prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
from abl.structures import ListData
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
################################
# Test for MNIST Add reasoning #
################################

class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
@@ -288,38 +289,54 @@ if __name__ == "__main__":
def logic_forward(self, nums):
return sum(nums)
class AddGroundKB(GroundKB):
class AddGroundKB(GroundKB, AddKB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)


def logic_forward(self, nums):
return sum(nums)
def logic_forward(self, nums):
return sum(nums)
def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
# favor 1 in first one
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
# favor 7 in first one
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
data_samples_add = ListData()
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
print(res)
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9]
print()

print("AddKB with GKB:")
print("AddGroundKB:")
kb = AddGroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("AddKB without GKB:")
print("AddKB:")
kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("AddKB without GKB, no cache")
print("AddKB, no cache")
kb = AddKB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)
@@ -337,45 +354,20 @@ if __name__ == "__main__":
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("AddKB with multiple inputs at once:")
multiple_prob = [[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]]

kb = AddKB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

################################
#### Test for HWF reasoning ####
################################
class HwfKB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
max_err=1e-3,
use_cache=False,
):
super().__init__(pseudo_label_list, max_err)
super().__init__(pseudo_label_list, max_err, use_cache)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
@@ -395,7 +387,7 @@ if __name__ == "__main__":
formula = [mapping[f] for f in formula]
return eval("".join(formula))
class HwfGroundKB(GroundKB):
class HwfGroundKB(GroundKB, HwfKB):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
@@ -405,6 +397,17 @@ if __name__ == "__main__":
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)


def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True
def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
@@ -415,6 +418,16 @@ if __name__ == "__main__":
return False
return True


def logic_forward(self, formula):
if not self._valid_candidate(formula):
return None
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))
def logic_forward(self, formula):
if not self._valid_candidate(formula):
return None
@@ -424,87 +437,46 @@ if __name__ == "__main__":
return eval("".join(formula))
def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
data_samples_hwf = ListData()
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]]
data_samples_hwf.pred_prob = [None, None, None, None]
data_samples_hwf.Y = [3, 64, 65, 3.17]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
print(res)
print()
def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HwfKB with GKB, max_err=0.1")
print("HwfGroundKB, max_err=0.1:")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB without GKB, max_err=0.1")
print("HwfKB, max_err=0.1:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB with GKB, max_err=1")
print("HwfGroundKB, max_err=1:")
kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB without GKB, max_err=1")
print("HwfKB, max_err=1:")
kb = HwfKB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HwfKB with multiple inputs at once:")
kb = HwfKB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1,3,3])
print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9])

################################
#### Test for HED reasoning ####
################################
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)
@@ -540,7 +512,7 @@ if __name__ == "__main__":
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
all_revision_flag = reform_list(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
@@ -597,28 +569,24 @@ if __name__ == "__main__":
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HedKB logic forward")
print(kb.logic_forward(consist_exs))
print("HedKB logic forward:")
print(kb.logic_forward(consist_exs), end=" ")
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HedKB consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print("HedKB consist rule:")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ")
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

data_sample_hed = ListData()
data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2]
data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)]
data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)]

print("HedReasoner abduce")
res = reasoner.abduce(
[[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
res = reasoner.batch_abduce(data_sample_hed)
for r in res:
print(r)
print()

print("HedReasoner abduce rules")


+ 2
- 0
abl/structures/__init__.py View File

@@ -0,0 +1,2 @@
from .base_data_element import BaseDataElement
from .list_data import ListData

+ 629
- 0
abl/structures/base_data_element.py View File

@@ -0,0 +1,629 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Iterator, Optional, Tuple, Type, Union

import numpy as np
import torch


class BaseDataElement:
"""A base data interface that supports Tensor-like and dict-like
operations.

A typical data elements refer to predicted results or ground truth labels
on a task, such as predicted bboxes, instance masks, semantic
segmentation masks, etc. Because groundtruth labels and predicted results
often have similar properties (for example, the predicted bboxes and the
groundtruth bboxes), MMEngine uses the same abstract data interface to
encapsulate predicted results and groundtruth labels, and it is recommended
to use different name conventions to distinguish them, such as using
``gt_instances`` and ``pred_instances`` to distinguish between labels and
predicted results. Additionally, we distinguish data elements at instance
level, pixel level, and label level. Each of these types has its own
characteristics. Therefore, MMEngine defines the base class
``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and
``LabelData`` inheriting from ``BaseDataElement`` to represent different
types of ground truth labels or predictions.

Another common data element is sample data. A sample data consists of input
data (such as an image) and its annotations and predictions. In general,
an image can have multiple types of annotations and/or predictions at the
same time (for example, both pixel-level semantic segmentation annotations
and instance-level detection bboxes annotations). All labels and
predictions of a training sample are often passed between Dataset, Model,
Visualizer, and Evaluator components. In order to simplify the interface
between components, we can treat them as a large data element and
encapsulate them. Such data elements are generally called XXDataSample in
the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement`
allows `BaseDataElement` as its attribute. Such a class generally
encapsulates all the data of a sample in the algorithm library, and its
attributes generally are various types of data elements. For example,
MMDetection is assigned by the BaseDataElement to encapsulate all the data
elements of the sample labeling and prediction of a sample in the
algorithm library.

The attributes in ``BaseDataElement`` are divided into two parts,
the ``metainfo`` and the ``data`` respectively.

- ``metainfo``: Usually contains the
information about the image such as filename,
image_shape, pad_shape, etc. The attributes can be accessed or
modified by dict-like or object-like operations, such as
``.`` (for data access and modification), ``in``, ``del``,
``pop(str)``, ``get(str)``, ``metainfo_keys()``,
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
set or change key-value pairs in metainfo).

- ``data``: Annotations or model predictions are
stored. The attributes can be accessed or modified by
dict-like or object-like operations, such as
``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
``values()``, ``items()``. Users can also apply tensor-like
methods to all :obj:`torch.Tensor` in the ``data_fields``,
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``,
``to_tensor()``, ``.detach()``.

Args:
metainfo (dict, optional): A dict contains the meta information
of single image, such as ``dict(img_shape=(512, 512, 3),
scale_factor=(1, 1, 1, 1))``. Defaults to None.
kwargs (dict, optional): A dict contains annotations of single image or
model predictions. Defaults to None.

Examples:
>>> import torch
>>> from mmengine.structures import BaseDataElement
>>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4))
>>> scores = torch.rand((5,))
>>> img_id = 0
>>> img_shape = (800, 1333)
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=img_id, img_shape=img_shape),
... bboxes=bboxes, scores=scores)
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=img_id, img_shape=(640, 640)))

>>> # new
>>> gt_instances1 = gt_instances.new(
... metainfo=dict(img_id=1, img_shape=(640, 640)),
... bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,)))
>>> gt_instances2 = gt_instances1.new()

>>> # add and process property
>>> gt_instances = BaseDataElement()
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
>>> assert 'img_shape' in gt_instances.metainfo_keys()
>>> assert 'img_shape' in gt_instances
>>> assert 'img_shape' not in gt_instances.keys()
>>> assert 'img_shape' in gt_instances.all_keys()
>>> print(gt_instances.img_shape)
(100, 100)
>>> gt_instances.scores = torch.rand((5,))
>>> assert 'scores' in gt_instances.keys()
>>> assert 'scores' in gt_instances
>>> assert 'scores' in gt_instances.all_keys()
>>> assert 'scores' not in gt_instances.metainfo_keys()
>>> print(gt_instances.scores)
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876])
>>> gt_instances.bboxes = torch.rand((5, 4))
>>> assert 'bboxes' in gt_instances.keys()
>>> assert 'bboxes' in gt_instances
>>> assert 'bboxes' in gt_instances.all_keys()
>>> assert 'bboxes' not in gt_instances.metainfo_keys()
>>> print(gt_instances.bboxes)
tensor([[0.0900, 0.0424, 0.1755, 0.4469],
[0.8648, 0.0592, 0.3484, 0.0913],
[0.5808, 0.1909, 0.6165, 0.7088],
[0.5490, 0.4209, 0.9416, 0.2374],
[0.3652, 0.1218, 0.8805, 0.7523]])

>>> # delete and change property
>>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=0, img_shape=(640, 640)),
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
>>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
>>> gt_instances.img_shape # (1280, 1280)
>>> gt_instances.bboxes = gt_instances.bboxes * 2
>>> gt_instances.get('img_shape', None) # (1280, 1280)
>>> gt_instances.get('bboxes', None) # 6x4 tensor
>>> del gt_instances.img_shape
>>> del gt_instances.bboxes
>>> assert 'img_shape' not in gt_instances
>>> assert 'bboxes' not in gt_instances
>>> gt_instances.pop('img_shape', None) # None
>>> gt_instances.pop('bboxes', None) # None

>>> # Tensor-like
>>> cuda_instances = gt_instances.cuda()
>>> cuda_instances = gt_instances.to('cuda:0')
>>> cpu_instances = cuda_instances.cpu()
>>> cpu_instances = cuda_instances.to('cpu')
>>> fp16_instances = cuda_instances.to(
... device=None, dtype=torch.float16, non_blocking=False,
... copy=False, memory_format=torch.preserve_format)
>>> cpu_instances = cuda_instances.detach()
>>> np_instances = cpu_instances.numpy()

>>> # print
>>> metainfo = dict(img_shape=(800, 1196, 3))
>>> gt_instances = BaseDataElement(
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
>>> sample = BaseDataElement(metainfo=metainfo,
... gt_instances=gt_instances)
>>> print(sample)
<BaseDataElement(
META INFORMATION
img_shape: (800, 1196, 3)
DATA FIELDS
gt_instances: <BaseDataElement(
META INFORMATION
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([0, 1, 2, 3])
) at 0x7f0ec5eadc70>
) at 0x7f0fea49e130>

>>> # inheritance
>>> class DetDataSample(BaseDataElement):
... @property
... def proposals(self):
... return self._proposals
... @proposals.setter
... def proposals(self, value):
... self.set_field(value, '_proposals', dtype=BaseDataElement)
... @proposals.deleter
... def proposals(self):
... del self._proposals
... @property
... def gt_instances(self):
... return self._gt_instances
... @gt_instances.setter
... def gt_instances(self, value):
... self.set_field(value, '_gt_instances',
... dtype=BaseDataElement)
... @gt_instances.deleter
... def gt_instances(self):
... del self._gt_instances
... @property
... def pred_instances(self):
... return self._pred_instances
... @pred_instances.setter
... def pred_instances(self, value):
... self.set_field(value, '_pred_instances',
... dtype=BaseDataElement)
... @pred_instances.deleter
... def pred_instances(self):
... del self._pred_instances
>>> det_sample = DetDataSample()
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
>>> det_sample.proposals = proposals
>>> assert 'proposals' in det_sample
>>> assert det_sample.proposals == proposals
>>> del det_sample.proposals
>>> assert 'proposals' not in det_sample
>>> with self.assertRaises(AssertionError):
... det_sample.proposals = torch.rand((5, 4))
"""

def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
self._metainfo_fields: set = set()
self._data_fields: set = set()

if metainfo is not None:
self.set_metainfo(metainfo=metainfo)
if kwargs:
self.set_data(kwargs)

def set_metainfo(self, metainfo: dict) -> None:
"""Set or change key-value pairs in ``metainfo_field`` by parameter
``metainfo``.

Args:
metainfo (dict): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
"""
assert isinstance(
metainfo, dict
), f"metainfo should be a ``dict`` but got {type(metainfo)}"
meta = copy.deepcopy(metainfo)
for k, v in meta.items():
self.set_field(name=k, value=v, field_type="metainfo", dtype=None)

def set_data(self, data: dict) -> None:
"""Set or change key-value pairs in ``data_field`` by parameter
``data``.

Args:
data (dict): A dict contains annotations of image or
model predictions.
"""
assert isinstance(data, dict), f"data should be a `dict` but got {data}"
for k, v in data.items():
# Use `setattr()` rather than `self.set_field` to allow `set_data`
# to set property method.
setattr(self, k, v)

def update(self, instance: "BaseDataElement") -> None:
"""The update() method updates the BaseDataElement with the elements
from another BaseDataElement object.

Args:
instance (BaseDataElement): Another BaseDataElement object for
update the current object.
"""
assert isinstance(
instance, BaseDataElement
), f"instance should be a `BaseDataElement` but got {type(instance)}"
self.set_metainfo(dict(instance.metainfo_items()))
self.set_data(dict(instance.items()))

def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement":
"""Return a new data element with same type. If ``metainfo`` and
``data`` are None, the new data element will have same metainfo and
data. If metainfo or data is not None, the new result will overwrite it
with the input value.

Args:
metainfo (dict, optional): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
Defaults to None.
kwargs (dict): A dict contains annotations of image or
model predictions.

Returns:
BaseDataElement: A new data element with same type.
"""
new_data = self.__class__()

if metainfo is not None:
new_data.set_metainfo(metainfo)
else:
new_data.set_metainfo(dict(self.metainfo_items()))
if kwargs:
new_data.set_data(kwargs)
else:
new_data.set_data(dict(self.items()))
return new_data

def clone(self):
"""Deep copy the current data element.

Returns:
BaseDataElement: The copy of current data element.
"""
clone_data = self.__class__()
clone_data.set_metainfo(dict(self.metainfo_items()))
clone_data.set_data(dict(self.items()))
return clone_data

def keys(self) -> list:
"""
Returns:
list: Contains all keys in data_fields.
"""
# We assume that the name of the attribute related to property is
# '_' + the name of the property. We use this rule to filter out
# private keys.
# TODO: Use a more robust way to solve this problem
private_keys = {
"_" + key
for key in self._data_fields
if isinstance(getattr(type(self), key, None), property)
}
return list(self._data_fields - private_keys)

def metainfo_keys(self) -> list:
"""
Returns:
list: Contains all keys in metainfo_fields.
"""
return list(self._metainfo_fields)

def values(self) -> list:
"""
Returns:
list: Contains all values in data.
"""
return [getattr(self, k) for k in self.keys()]

def metainfo_values(self) -> list:
"""
Returns:
list: Contains all values in metainfo.
"""
return [getattr(self, k) for k in self.metainfo_keys()]

def all_keys(self) -> list:
"""
Returns:
list: Contains all keys in metainfo and data.
"""
return self.metainfo_keys() + self.keys()

def all_values(self) -> list:
"""
Returns:
list: Contains all values in metainfo and data.
"""
return self.metainfo_values() + self.values()

def all_items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo`` and ``data``.
"""
for k in self.all_keys():
yield (k, getattr(self, k))

def items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``data``.
"""
for k in self.keys():
yield (k, getattr(self, k))

def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
"""
Returns:
iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo``.
"""
for k in self.metainfo_keys():
yield (k, getattr(self, k))

@property
def metainfo(self) -> dict:
"""dict: A dict contains metainfo of current data element."""
return dict(self.metainfo_items())

def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data."""
if name in ("_metainfo_fields", "_data_fields"):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f"{name} has been used as a "
"private attribute, which is immutable."
)
else:
self.set_field(name=name, value=value, field_type="data", dtype=None)

def __delattr__(self, item: str):
"""Delete the item in dataelement.

Args:
item (str): The key to delete.
"""
if item in ("_metainfo_fields", "_data_fields"):
raise AttributeError(
f"{item} has been used as a " "private attribute, which is immutable."
)
super().__delattr__(item)
if item in self._metainfo_fields:
self._metainfo_fields.remove(item)
elif item in self._data_fields:
self._data_fields.remove(item)

# dict-like methods
__delitem__ = __delattr__

def get(self, key, default=None) -> Any:
"""Get property in data and metainfo as the same as python."""
# Use `getattr()` rather than `self.__dict__.get()` to allow getting
# properties.
return getattr(self, key, default)

def pop(self, *args) -> Any:
"""Pop property in data and metainfo as the same as python."""
assert len(args) < 3, "``pop`` get more than 2 arguments"
name = args[0]
if name in self._metainfo_fields:
self._metainfo_fields.remove(args[0])
return self.__dict__.pop(*args)

elif name in self._data_fields:
self._data_fields.remove(args[0])
return self.__dict__.pop(*args)

# with default value
elif len(args) == 2:
return args[1]
else:
# don't just use 'self.__dict__.pop(*args)' for only popping key in
# metainfo or data
raise KeyError(f"{args[0]} is not contained in metainfo or data")

def __contains__(self, item: str) -> bool:
"""Whether the item is in dataelement.

Args:
item (str): The key to inquire.
"""
return item in self._data_fields or item in self._metainfo_fields

def set_field(
self,
value: Any,
name: str,
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
field_type: str = "data",
) -> None:
"""Special method for set union field, used as property.setter
functions."""
assert field_type in ["metainfo", "data"]
if dtype is not None:
assert isinstance(
value, dtype
), f"{value} should be a {dtype} but got {type(value)}"

if field_type == "metainfo":
if name in self._data_fields:
raise AttributeError(
f"Cannot set {name} to be a field of metainfo "
f"because {name} is already a data field"
)
self._metainfo_fields.add(name)
else:
if name in self._metainfo_fields:
raise AttributeError(
f"Cannot set {name} to be a field of data "
f"because {name} is already a metainfo field"
)
self._data_fields.add(name)
super().__setattr__(name, value)

# Tensor-like methods
def to(self, *args, **kwargs) -> "BaseDataElement":
"""Apply same name function to all tensors in data_fields."""
new_data = self.new()
for k, v in self.items():
if hasattr(v, "to"):
v = v.to(*args, **kwargs)
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def cpu(self) -> "BaseDataElement":
"""Convert all tensors to CPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.cpu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def cuda(self) -> "BaseDataElement":
"""Convert all tensors to GPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.cuda()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def npu(self) -> "BaseDataElement":
"""Convert all tensors to NPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.npu()
data = {k: v}
new_data.set_data(data)
return new_data

def mlu(self) -> "BaseDataElement":
"""Convert all tensors to MLU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.mlu()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def detach(self) -> "BaseDataElement":
"""Detach all tensors in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.detach()
data = {k: v}
new_data.set_data(data)
return new_data

# Tensor-like methods
def numpy(self) -> "BaseDataElement":
"""Convert all tensors to np.ndarray in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.detach().cpu().numpy()
data = {k: v}
new_data.set_data(data)
return new_data

def to_tensor(self) -> "BaseDataElement":
"""Convert all np.ndarray to tensor in data."""
new_data = self.new()
for k, v in self.items():
data = {}
if isinstance(v, np.ndarray):
v = torch.from_numpy(v)
data[k] = v
elif isinstance(v, BaseDataElement):
v = v.to_tensor()
data[k] = v
new_data.set_data(data)
return new_data

def to_dict(self) -> dict:
"""Convert BaseDataElement to dict."""
return {
k: v.to_dict() if isinstance(v, BaseDataElement) else v
for k, v in self.all_items()
}

def __repr__(self) -> str:
"""Represent the object."""

def _addindent(s_: str, num_spaces: int) -> str:
"""This func is modified from `pytorch` https://github.com/pytorch/
pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
les/module.py#L29.

Args:
s_ (str): The string to add spaces.
num_spaces (int): The num of space to add.

Returns:
str: The string after add indent.
"""
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s) # type: ignore
s = first + "\n" + s # type: ignore
return s # type: ignore

def dump(obj: Any) -> str:
"""Represent the object.

Args:
obj (Any): The obj to represent.

Returns:
str: The represented str.
"""
_repr = ""
if isinstance(obj, dict):
for k, v in obj.items():
_repr += f"\n{k}: {_addindent(dump(v), 4)}"
elif isinstance(obj, BaseDataElement):
_repr += "\n\n META INFORMATION"
metainfo_items = dict(obj.metainfo_items())
_repr += _addindent(dump(metainfo_items), 4)
_repr += "\n\n DATA FIELDS"
items = dict(obj.items())
_repr += _addindent(dump(items), 4)
classname = obj.__class__.__name__
_repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>"
else:
_repr += repr(obj)
return _repr

return dump(self)

+ 305
- 0
abl/structures/list_data.py View File

@@ -0,0 +1,305 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import Any, List, Union

import numpy as np
import torch

from ..utils import flatten as flatten_list
from ..utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]


# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class ListData(BaseDataElement):
"""Data structure for instance-level annotations or predictions.

Subclass of :class:`BaseDataElement`. All value in `data_fields`
should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.

Examples:
>>> # custom data structure
>>> class TmpObject:
... def __init__(self, tmp) -> None:
... assert isinstance(tmp, list)
... self.tmp = tmp
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if isinstance(item, int):
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
... # keep the dimension
... item = slice(item, None, len(self))
... return TmpObject(self.tmp[item])
... @staticmethod
... def cat(tmp_objs):
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
... if len(tmp_objs) == 1:
... return tmp_objs[0]
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
... tmp_list = list(itertools.chain(*tmp_list))
... new_data = TmpObject(tmp_list)
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.structures import ListData
>>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = ListData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
2
>>> print(instance_data)
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2])
det_scores: tensor([0.8000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: []
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263],
[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0>
"""

def __setattr__(self, name: str, value: list):
"""setattr is only used to set data.

The value must have the attribute of `__len__` and have the same length
of `ListData`.
"""
if name in ("_metainfo_fields", "_data_fields"):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f"{name} has been used as a " "private attribute, which is immutable."
)

else:
# assert isinstance(value, list), "value must be of type `list`"

# if len(self) > 0:
# assert len(value) == len(self), (
# "The length of "
# f"values {len(value)} is "
# "not consistent with "
# "the length of this "
# ":obj:`ListData` "
# f"{len(self)}"
# )
super().__setattr__(name, value)

__setitem__ = __setattr__

def __getitem__(self, item: IndexType) -> "ListData":
"""
Args:
item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
:obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
Get the corresponding values according to item.

Returns:
:obj:`ListData`: Corresponding values.
"""
assert isinstance(item, IndexType.__args__)
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# windows and int64 for linux. `torch.Tensor` requires the index
# should be int64, therefore we simply convert it to int64 here.
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)

if isinstance(item, str):
return getattr(self, item)

new_data = self.__class__(metainfo=self.metainfo)

if isinstance(item, torch.Tensor):
assert item.dim() == 1, "Only support to get the" " values along the first dimension."

for k, v in self.items():
if v is None:
new_data[k] = None
elif isinstance(v, torch.Tensor):
new_data[k] = v[item]
elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()]
elif isinstance(v, (str, list, tuple)) or (
hasattr(v, "__getitem__") and hasattr(v, "cat")
):
# convert to indexes from BoolTensor
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
else:
indexes = item.cpu().numpy().tolist()
slice_list = []
if indexes:
for index in indexes:
slice_list.append(slice(index, None, len(v)))
else:
slice_list.append(slice(None, 0, None))
r_list = [v[s] for s in slice_list]
if isinstance(v, (str, list, tuple)):
new_value = r_list[0]
for r in r_list[1:]:
new_value = new_value + r
else:
new_value = v.cat(r_list)
new_data[k] = new_value
else:
raise ValueError(
f"The type of `{k}` is `{type(v)}`, which has no "
"attribute of `cat`, so it does not "
"support slice with `bool`"
)

else:
# item is a slice or int
for k, v in self.items():
if v is None:
new_data[k] = None
else:
new_data[k] = v[item]
return new_data # type:ignore

@staticmethod
def cat(instances_list: List["ListData"]) -> "ListData":
"""Concat the instances of all :obj:`ListData` in the list.

Note: To ensure that cat returns as expected, make sure that
all elements in the list must have exactly the same keys.

Args:
instances_list (list[:obj:`ListData`]): A list
of :obj:`ListData`.

Returns:
:obj:`ListData`
"""
assert all(isinstance(results, ListData) for results in instances_list)
assert len(instances_list) > 0
if len(instances_list) == 1:
return instances_list[0]

# metainfo and data_fields must be exactly the
# same for each element to avoid exceptions.
field_keys_list = [instances.all_keys() for instances in instances_list]
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len(
set(itertools.chain(*field_keys_list))
) == len(field_keys_list[0]), (
"There are different keys in "
"`instances_list`, which may "
"cause the cat operation "
"to fail. Please make sure all "
"elements in `instances_list` "
"have the exact same key."
)

new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys():
values = [results[k] for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
new_values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray):
new_values = np.concatenate(values, axis=0)
elif isinstance(v0, (str, list, tuple)):
new_values = v0[:]
for v in values[1:]:
new_values += v
elif hasattr(v0, "cat"):
new_values = v0.cat(values)
else:
raise ValueError(
f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`"
)
new_data[k] = new_values
return new_data # type:ignore

def flatten(self, item: IndexType) -> List:
"""Flatten self[item].

Returns:
list: Flattened data fields.
"""
return flatten_list(self[item])

def elements_num(self, item: IndexType) -> int:
"""int: The number of elements in self[item]."""
return len(self.flatten(item))

def to_tuple(self, item: IndexType) -> tuple:
"""tuple: The data fields in self[item] converted to tuple."""
return to_hashable(self[item])

def __len__(self) -> int:
"""int: The length of ListData."""
if len(self._data_fields) > 0:
one_element = next(iter(self._data_fields))
return len(getattr(self, one_element))
# return len(self.values()[0])
else:
return 0

+ 2
- 1
abl/utils/__init__.py View File

@@ -1,2 +1,3 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import *
from .utils import *

+ 104
- 0
abl/utils/cache.py View File

@@ -0,0 +1,104 @@
import pickle
import os
import os.path as osp
from typing import Callable, Generic, TypeVar

from .logger import print_log, ABLLogger

K = TypeVar("K")
T = TypeVar("T")
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields


class Cache(Generic[K, T]):
def __init__(self, func: Callable[[K], T]):
"""Create cache

:param func: Function this cache evaluates
:param cache: If true, do in memory caching.
:param cache_root: If not None, cache to files at the provided path.
:param key_func: Convert the key into a hashable object if needed
"""
self.func = func
self.has_init = False

def __getitem__(self, obj, *args) -> T:
return self.get_from_dict(obj, *args)

def clear_cache(self):
"""Invalidate entire cache."""
self.cache_dict.clear()

def _init_cache(self, obj):
if self.has_init:
return

self.cache = True
self.cache_dict = dict()
self.key_func = obj.key_func
self.max_size = obj.max_cache_size

self.hits, self.misses = 0, 0
self.full = False
self.root = [] # root of the circular doubly linked list
self.root[:] = [self.root, self.root, None, None]

self.has_init = True

def get_from_dict(self, obj, *args) -> T:
"""Implements dict based cache."""
pred_pseudo_label, y, *res_args = args
cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
link = self.cache_dict.get(cache_key)
if link is not None:
# Move the link to the front of the circular queue
link_prev, link_next, _key, result = link
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
last = self.root[PREV]
last[NEXT] = self.root[PREV] = link
link[PREV] = last
link[NEXT] = self.root
self.hits += 1
return result
self.misses += 1

result = self.func(obj, *args)

if self.full:
# Use the old root to store the new key and result.
oldroot = self.root
oldroot[KEY] = cache_key
oldroot[RESULT] = result
# Empty the oldest link and make it the new root.
self.root = oldroot[NEXT]
oldkey = self.root[KEY]
oldresult = self.root[RESULT]
self.root[KEY] = self.root[RESULT] = None
# Now update the cache dictionary.
del self.cache_dict[oldkey]
self.cache_dict[cache_key] = oldroot
else:
# Put result in a new link at the front of the queue.
last = self.root[PREV]
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
if isinstance(self.max_size, int):
self.full = len(self.cache_dict) >= self.max_size
return result


def abl_cache():
def decorator(func):
cache_instance = Cache(func)

def wrapper(obj, *args):
if obj.use_cache:
cache_instance._init_cache(obj)
return cache_instance.get_from_dict(obj, *args)
else:
return func(obj, *args)

return wrapper

return decorator

+ 57
- 7
abl/utils/utils.py View File

@@ -1,6 +1,7 @@
import numpy as np
from itertools import chain

import numpy as np


def flatten(nested_list):
"""
@@ -15,6 +16,11 @@ def flatten(nested_list):
-------
list
A flattened version of the input list.

Raises
------
TypeError
If the input object is not a list.
"""
if not isinstance(nested_list, list):
raise TypeError("Input must be of type list.")
@@ -25,7 +31,7 @@ def flatten(nested_list):
return list(chain.from_iterable(nested_list))


def reform_idx(flattened_list, structured_list):
def reform_list(flattened_list, structured_list):
"""
Reform the index based on structured_list structure.

@@ -41,6 +47,9 @@ def reform_idx(flattened_list, structured_list):
list
A reformed list that mimics the structure of structured_list.
"""
# if not isinstance(flattened_list, list):
# raise TypeError("Input must be of type list.")

if not isinstance(structured_list[0], (list, tuple)):
return flattened_list

@@ -80,7 +89,7 @@ def hamming_dist(pred_pseudo_label, candidates):
return np.sum(pred_pseudo_label != candidates, axis=1)


def confidence_dist(pred_prob, candidates_idx):
def confidence_dist(pred_prob, candidates):
"""
Compute the confidence distance between prediction probabilities and candidates.

@@ -89,7 +98,7 @@ def confidence_dist(pred_prob, candidates_idx):
pred_prob : list of numpy.ndarray
Prediction probability distributions, each element is an ndarray
representing the probability distribution of a particular prediction.
candidates_idx : list of list of int
candidates : list of list of int
Index of candidate labels, each element is a list of indexes being considered
as a candidate correction.

@@ -99,8 +108,8 @@ def confidence_dist(pred_prob, candidates_idx):
Confidence distances computed for each candidate.
"""
pred_prob = np.clip(pred_prob, 1e-9, 1)
_, cols = np.indices((len(candidates_idx), len(candidates_idx[0])))
return 1 - np.prod(pred_prob[cols, candidates_idx], axis=1)
_, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[cols, candidates], axis=1)


def block_sample(X, Z, Y, sample_num, seg_idx):
@@ -154,6 +163,7 @@ def to_hashable(x):
return tuple(to_hashable(item) for item in x)
return x


def restore_from_hashable(x):
"""
Convert a nested tuple back to a nested list.
@@ -170,10 +180,49 @@ def restore_from_hashable(x):
otherwise the original input.
"""
if isinstance(x, tuple):
return [hashable_to_list(item) for item in x]
return [restore_from_hashable(item) for item in x]
return x


def calculate_revision_num(parameter, total_length):
"""
Convert a float parameter to an integer, based on a total length.

Parameters
----------
parameter : int or float
The parameter to convert. If float, it should be between 0 and 1.
If int, it should be non-negative. If -1, it will be replaced with total_length.
total_length : int
The total length to calculate the parameter from if it's a fraction.

Returns
-------
int
The calculated parameter.

Raises
------
TypeError
If parameter is not an int or a float.
ValueError
If parameter is a float not in [0, 1] or an int below 0.
"""
if not isinstance(parameter, (int, float)):
raise TypeError("Parameter must be of type int or float.")

if parameter == -1:
return total_length
elif isinstance(parameter, float):
if not (0 <= parameter <= 1):
raise ValueError("If parameter is a float, it must be between 0 and 1.")
return round(total_length * parameter)
else:
if parameter < 0:
raise ValueError("If parameter is an int, it must be non-negative.")
return parameter


if __name__ == "__main__":
A = np.array(
[
@@ -227,4 +276,5 @@ if __name__ == "__main__":
)
B = [[0, 9, 3], [0, 11, 4]]

print(ori_confidence_dist(A, B))
print(confidence_dist(A, B))

+ 12
- 12
examples/hed/hed_bridge.py View File

@@ -19,17 +19,17 @@ class HEDBridge(SimpleBridge):
def __init__(
self,
model: ABLModel,
abducer: ReasonerBase,
reasoner: ReasonerBase,
metric_list: BaseMetric,
) -> None:
super().__init__(model, abducer, metric_list)
super().__init__(model, reasoner, metric_list)

def pretrain(self, weights_dir):
if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")):
print_log("Pretrain Start", logger="current")

cls_autoencoder = SymbolNetAutoencoder(
num_classes=len(self.abducer.kb.pseudo_label_list)
num_classes=len(self.reasoner.kb.pseudo_label_list)
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.MSELoss()
@@ -74,7 +74,7 @@ class HEDBridge(SimpleBridge):
max_revision=-1,
require_more_revision=0,
):
return self.abducer.abduce(
return self.reasoner.abduce(
(pred_label, pred_prob, pseudo_label, Y),
max_revision,
require_more_revision,
@@ -86,8 +86,8 @@ class HEDBridge(SimpleBridge):
pred_pseudo_label_list = []
abduced_pseudo_label_list = []
for _mapping in candidate_mappings:
self.abducer.mapping = _mapping
self.abducer.set_remapping()
self.reasoner.mapping = _mapping
self.reasoner.set_remapping()
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
abduced_pseudo_label = self.abduce_pseudo_label(
pred_label, pred_prob, pred_pseudo_label, Y, 20
@@ -100,8 +100,8 @@ class HEDBridge(SimpleBridge):

max_revisible_instances = max(mapping_score)
return_idx = mapping_score.index(max_revisible_instances)
self.abducer.mapping = candidate_mappings[return_idx]
self.abducer.set_remapping()
self.reasoner.mapping = candidate_mappings[return_idx]
self.reasoner.set_remapping()
return abduced_pseudo_label_list[return_idx]

def check_training_impact(self, filtered_X, filtered_abduced_label, X):
@@ -137,7 +137,7 @@ class HEDBridge(SimpleBridge):
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
consistent_num = sum(
[
self.abducer.kb.consist_rule(instance, rule)
self.reasoner.kb.consist_rule(instance, rule)
for instance in pred_pseudo_label
]
)
@@ -159,11 +159,11 @@ class HEDBridge(SimpleBridge):
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
consistent_instance = []
for instance in pred_pseudo_label:
if self.abducer.kb.logic_forward([instance]):
if self.reasoner.kb.logic_forward([instance]):
consistent_instance.append(instance)

if len(consistent_instance) != 0:
rule = self.abducer.abduce_rules(consistent_instance)
rule = self.reasoner.abduce_rules(consistent_instance)
if rule != None:
rules.append(rule)
break
@@ -280,7 +280,7 @@ class HEDBridge(SimpleBridge):
else:
if equation_len == min_len:
print_log(
"Learned mapping is: " + str(self.abducer.mapping),
"Learned mapping is: " + str(self.reasoner.mapping),
logger="current",
)
self.model.load(load_path="./weights/pretrain_weights.pth")


+ 142
- 27
examples/hwf/hwf_example.ipynb View File

@@ -2,10 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.setrecursionlimit(10000)\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import torch.nn as nn\n",
@@ -23,9 +27,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/16 20:43:38 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the HWF example.\n"
]
}
],
"source": [
"# Initialize logger and print basic information\n",
"print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n",
@@ -45,21 +57,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"# Initialize knowledge base and reasoner\n",
"class HWF_KB(KBBase):\n",
" def __init__(\n",
" self, \n",
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n",
" prebuild_GKB=False,\n",
" GKB_len_list=[1, 3, 5, 7],\n",
" max_err=1e-3,\n",
" use_cache=True\n",
" ):\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def _valid_candidate(self, formula):\n",
" if len(formula) % 2 == 0:\n",
@@ -79,8 +82,8 @@
" formula = [mapping[f] for f in formula]\n",
" return eval(''.join(formula))\n",
"\n",
"kb = HWF_KB(prebuild_GKB=True)\n",
"abducer = ReasonerBase(kb, dist_func='confidence')"
"kb = HWF_KB(pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], max_err=1e-10, use_cache=False)\n",
"reasoner = ReasonerBase(kb, dist_func='confidence')"
]
},
{
@@ -93,7 +96,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -106,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -126,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -146,12 +149,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Add metric\n",
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]"
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]"
]
},
{
@@ -164,7 +167,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -183,11 +186,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)"
"bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)"
]
},
{
@@ -200,11 +203,123 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:02 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [1/10] model loss is 0.16911\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [2/10] model loss is 0.17734\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:03 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [3/10] model loss is 0.01907\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:04 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [4/10] model loss is 0.01403\n",
"11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:05 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [5/10] model loss is 0.00509\n",
"11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [6/10] model loss is 0.00713\n",
"11/16 20:44:06 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [7/10] model loss is 0.00455\n",
"11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:07 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [8/10] model loss is 0.00946\n",
"11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:08 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [9/10] model loss is 0.00957\n",
"11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/3] segment(train) [10/10] model loss is 0.00323\n",
"11/16 20:44:09 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.997 hwf/semantics_accuracy: 0.985 \n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_1.pth\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [1/10] model loss is 0.00666\n",
"11/16 20:44:10 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_3.pth\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [2/10] model loss is 0.01438\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [3/10] model loss is 0.00450\n",
"11/16 20:44:11 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [4/10] model loss is 0.00764\n",
"11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:12 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [5/10] model loss is 0.00644\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [6/10] model loss is 0.00189\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:13 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [7/10] model loss is 0.00397\n",
"11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [8/10] model loss is 0.00936\n",
"11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:14 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [9/10] model loss is 0.00960\n",
"11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/3] segment(train) [10/10] model loss is 0.00572\n",
"11/16 20:44:15 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.999 hwf/semantics_accuracy: 0.995 \n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_2.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [1/10] model loss is 0.00180\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [2/10] model loss is 0.00615\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:16 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [3/10] model loss is 0.01000\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_2.pth\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [4/10] model loss is 0.00415\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [5/10] model loss is 0.00960\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:17 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [6/10] model loss is 0.00697\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [7/10] model loss is 0.00977\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [8/10] model loss is 0.00734\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:18 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [9/10] model loss is 0.00922\n",
"11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_epoch_1.pth\n",
"11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/3] segment(train) [10/10] model loss is 0.00982\n",
"11/16 20:44:19 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n",
"11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.998 hwf/semantics_accuracy: 0.986 \n",
"11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n",
"11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231116_20_43_38/weights/model_checkpoint_loop_3.pth\n",
"11/16 20:44:20 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, hwf/character_accuracy: 0.994 hwf/semantics_accuracy: 0.970 \n"
]
}
],
"source": [
"bridge.train(train_data, epochs=3, batch_size=1000)\n",
"bridge.train(train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
}


+ 25
- 24
examples/mnist_add/mnist_add_example.ipynb View File

@@ -2,10 +2,12 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
@@ -13,21 +15,25 @@
"\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, ABLMetric\n",
"from abl.utils import ABLLogger\n",
"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from models.nn import LeNet5\n",
"from examples.models.nn import LeNet5\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"logger = ABLLogger.get_instance(\"abl\")"
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
@@ -40,22 +46,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"# Initialize knowledge base and reasoner\n",
"class add_KB(KBBase):\n",
" def __init__(self, pseudo_label_list=list(range(10)), prebuild_GKB=False, GKB_len_list=[2], max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)\n",
"\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"kb = add_KB(prebuild_GKB=True)\n",
"kb = add_KB(pseudo_label_list=list(range(10)))\n",
"\n",
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")"
]
},
{
@@ -68,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -81,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -92,8 +95,6 @@
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=logger.save_dir,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
")"
@@ -109,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -129,12 +130,12 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add metric\n",
"metric = [SymbolMetric(prefix=\"mnist_add\"), ABLMetric(prefix=\"mnist_add\")]"
"metric = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]"
]
},
{
@@ -147,7 +148,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -170,7 +171,7 @@
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model, abducer, metric)"
"bridge = SimpleBridge(model, reasoner, metric)"
]
},
{
@@ -187,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
"bridge.train(train_data, epochs=5, batch_size=10000)\n",
"bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
}
@@ -208,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4,
"vscode": {


Loading…
Cancel
Save