|
|
|
@@ -1,83 +1,105 @@ |
|
|
|
Quick Start |
|
|
|
|
|
|
=========== |
|
|
|
|
|
|
|
|
|
|
We use the MNIST Add benchmark as a quick start example. In this task, the inputs are |
|
|
|
pairs of MNIST handwritten images, and the outputs are their sums. |
|
|
|
To complete this task, we first process the images through a machine learning model |
|
|
|
to get their corresponding pseudo labels (the number each image represents). |
|
|
|
Then, logical reasoning is applied to the recognized labels to calculate their sum. |
|
|
|
|
|
|
|
Load Data |
|
|
|
--------- |
|
|
|
|
|
|
|
ABL-Package expects data in the form ``(X, gt_pseudo_label, Y)``, |
|
|
|
where ``X`` is the input of the machine learning model, |
|
|
|
``Y`` is the ground truth of the reasoning result and |
|
|
|
``gt_pseudo_label`` is the ground truth label of each element in ``X``. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
from examples.mnist_add.datasets.get_mnist_add import get_mnist_add |
|
|
|
|
|
|
|
train_data = get_mnist_add(train=True, get_pseudo_label=True) |
|
|
|
test_data = get_mnist_add(train=False, get_pseudo_label=True) |
|
|
|
|
|
|
|
Each call to ``get_mnist_add`` above returns a tuple of the form ``(X, gt_pseudo_label, Y)``. |
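As a rough illustration of this layout, the three components can be pictured as parallel lists with one entry per example. This is a toy sketch with made-up values, not the actual output of ``get_mnist_add``: short strings stand in for the image tensors, and the name ``toy_data`` is hypothetical.

.. code:: python

    # Toy stand-in for the (X, gt_pseudo_label, Y) layout; all values are made up.
    # Strings replace the image tensors that a real loader would return.
    X = [["img_of_3", "img_of_5"], ["img_of_0", "img_of_9"]]
    gt_pseudo_label = [[3, 5], [0, 9]]
    Y = [8, 9]

    toy_data = (X, gt_pseudo_label, Y)

    # The lists are aligned: example i pairs X[i] with its labels and its sum.
    for images, labels, total in zip(*toy_data):
        assert len(images) == len(labels)
        assert sum(labels) == total
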
|
|
|
|
|
|
|
Machine Learning (Map input to pseudo labels) |
|
|
|
--------------------------------------------- |
|
|
|
|
|
|
|
We use a simple LeNet5 model to recognize the pseudo labels (numbers) in the images. |
|
|
|
We first build the model and define its corresponding criterion and optimizer for training. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
from abl.bridge import SimpleBridge |
|
|
|
from abl.evaluation import SemanticsMetric, SymbolMetric |
|
|
|
from abl.learning import ABLModel, BasicNN |
|
|
|
from abl.reasoning import KBBase, ReasonerBase |
|
|
|
from abl.utils import print_log |
|
|
|
from examples.mnist_add.datasets.get_mnist_add import get_mnist_add |
|
|
|
from examples.models.nn import LeNet5 |
|
|
|
|
|
|
|
# Build logger |
|
|
|
print_log("Abductive Learning on the MNIST Add example.", logger="current") |
|
|
|
|
|
|
|
# Machine Learning Part |
|
|
|
# Build necessary components for BasicNN |
|
|
|
cls = LeNet5(num_classes=10) |
|
|
|
criterion = nn.CrossEntropyLoss() |
|
|
|
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
# Build BasicNN |
|
|
|
|
|
|
|
Afterward, we wrap the network in ``BasicNN`` and then in ``ABLModel``. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
from abl.learning import ABLModel, BasicNN |
|
|
|
|
|
|
|
base_model = BasicNN(cls, criterion, optimizer, device) |
|
|
|
# Build ABLModel |
|
|
|
model = ABLModel(base_model) |
|
|
|
|
|
|
|
# Logic Part |
|
|
|
# Build knowledge base and reasoner |
|
|
|
Reasoning (Map pseudo labels to reasoning results) |
|
|
|
-------------------------------------------------- |
|
|
|
|
|
|
|
First, we build a knowledge base that defines how to deduce |
|
|
|
logical results (i.e., calculate summation) from the pseudo labels |
|
|
|
obtained by machine learning. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
from abl.reasoning import KBBase, ReasonerBase |
|
|
|
|
|
|
|
class AddKB(KBBase): |
|
|
|
    def __init__(self, pseudo_label_list=list(range(10))): |
|
|
|
        super().__init__(pseudo_label_list) |
|
|
|
|
|
|
|
    # Implement the deduction function |
|
|
|
    def logic_forward(self, nums): |
|
|
|
        return sum(nums) |
|
|
|
|
|
|
|
|
|
|
|
kb = AddKB(pseudo_label_list=list(range(10))) |
|
|
|
|
|
|
|
Then we create a reasoner, which specifies |
|
|
|
how to minimize the inconsistency between the knowledge base and machine learning. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
reasoner = ReasonerBase(kb, dist_func="confidence") |
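To make the idea of minimizing inconsistency more concrete, the following is a minimal, library-free sketch of one way it could work for MNIST Add: enumerate the label combinations that the knowledge base accepts for a given sum, and keep the one the model is most confident about. It is not ABL-Package's implementation; the function ``abduce`` and the probability lists are hypothetical and exist only for illustration.

.. code:: python

    from itertools import product


    def abduce(pred_probs, target_sum, labels=range(10)):
        """Return the label tuple that sums to target_sum with the highest confidence.

        pred_probs[i][k] is a (hypothetical) model probability that image i shows label k.
        """
        best, best_score = None, -1.0
        for candidate in product(labels, repeat=len(pred_probs)):
            if sum(candidate) != target_sum:
                continue  # inconsistent with the knowledge base, skip it
            score = 1.0
            for probs, label in zip(pred_probs, candidate):
                score *= probs[label]
            if score > best_score:
                best, best_score = candidate, score
        return best


    # The model prefers 3 for the first image and 7 for the second,
    # but 3 + 7 contradicts the known sum of 8.
    p0 = [0.0] * 10
    p0[3], p0[1] = 0.9, 0.1
    p1 = [0.0] * 10
    p1[7], p1[5] = 0.6, 0.4

    revised = abduce([p0, p1], target_sum=8)

Here the revision keeps the confident prediction for the first image and changes only the second label, since that is the consistent combination with the highest joint probability.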
|
|
|
|
|
|
|
# Datasets and Evaluation Metrics |
|
|
|
# Get training and testing data |
|
|
|
train_data = get_mnist_add(train=True, get_pseudo_label=True) |
|
|
|
test_data = get_mnist_add(train=False, get_pseudo_label=True) |
|
|
|
# Set up metrics |
|
|
|
Bridge Machine Learning and Reasoning |
|
|
|
------------------------------------- |
|
|
|
|
|
|
|
First, we use ``SimpleBridge`` to combine machine learning and reasoning, |
|
|
|
setting the stage for subsequent integrated training, validation, and testing. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
from abl.bridge import SimpleBridge |
|
|
|
|
|
|
|
|
|
|
|
Next, we define the metrics to measure accuracy during validation and testing. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
from abl.evaluation import SemanticsMetric, SymbolMetric |
|
|
|
|
|
|
|
metric_list = [SymbolMetric(prefix="mnist_add"), SemanticsMetric(kb=kb, prefix="mnist_add")] |
|
|
|
|
|
|
|
# Bridge Machine Learning and Logic Reasoning |
|
|
|
bridge = SimpleBridge(model, reasoner, metric_list) |
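The difference between the two metrics can be sketched without the library. The sketch assumes that ``SymbolMetric`` scores individual pseudo labels and that ``SemanticsMetric`` scores the result deduced by the knowledge base. The helper functions below are hypothetical stand-ins, not the classes' actual code.

.. code:: python

    def symbol_accuracy(pred_labels, gt_labels):
        """Fraction of individual pseudo labels that match the ground truth."""
        pairs = [
            (p, g)
            for preds, gts in zip(pred_labels, gt_labels)
            for p, g in zip(preds, gts)
        ]
        return sum(p == g for p, g in pairs) / len(pairs)


    def semantics_accuracy(pred_labels, Y, logic_forward=sum):
        """Fraction of examples whose deduced result matches the ground-truth result."""
        return sum(logic_forward(preds) == y for preds, y in zip(pred_labels, Y)) / len(Y)


    # Made-up predictions for three examples.
    pred = [[3, 5], [1, 8], [2, 2]]
    gt = [[3, 5], [0, 9], [2, 3]]
    Y = [8, 9, 5]

    sym_acc = symbol_accuracy(pred, gt)
    sem_acc = semantics_accuracy(pred, Y)

In this toy data the second example has both labels wrong yet still produces the correct sum, which is why the two scores can diverge.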
|
|
|
Finally, we proceed with training and testing. |
|
|
|
|
|
|
|
.. code:: python |
|
|
|
|
|
|
|
# Train and Test |
|
|
|
bridge.train(train_data, loops=5, segment_size=10000) |
|
|
|
bridge.test(test_data) |
|
|
|
|
|
|
|
|
|
|
|
The training log should look similar to this: |
|
|
|
|
|
|
|
.. code:: text |
|
|
|
|
|
|
|
2023/11/29 23:14:17 - abl - INFO - Abductive Learning on the MNIST Add example. |
|
|
|
2023/11/29 23:14:42 - abl - INFO - loop(train) [1/5] segment(train) [1/3] model loss is 1.86793 |
|
|
|
2023/11/29 23:14:44 - abl - INFO - loop(train) [1/5] segment(train) [2/3] model loss is 1.48877 |
|
|
|
2023/11/29 23:14:46 - abl - INFO - loop(train) [1/5] segment(train) [3/3] model loss is 1.26435 |
|
|
|
2023/11/29 23:14:46 - abl - INFO - Evaluation start: loop(val) [1] |
|
|
|
2023/11/29 23:14:47 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.334 mnist_add/semantics_accuracy: 0.190 |
|
|
|
2023/11/29 23:14:49 - abl - INFO - loop(train) [2/5] segment(train) [1/3] model loss is 1.06395 |
|
|
|
2023/11/29 23:14:51 - abl - INFO - loop(train) [2/5] segment(train) [2/3] model loss is 0.78799 |
|
|
|
2023/11/29 23:14:53 - abl - INFO - loop(train) [2/5] segment(train) [3/3] model loss is 0.33641 |
|
|
|
2023/11/29 23:14:53 - abl - INFO - Evaluation start: loop(val) [2] |
|
|
|
2023/11/29 23:14:54 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.963 mnist_add/semantics_accuracy: 0.926 |
|
|
|
... |
|
|
|
2023/11/29 23:15:08 - abl - INFO - loop(train) [5/5] segment(train) [1/3] model loss is 0.04223 |
|
|
|
2023/11/29 23:15:10 - abl - INFO - loop(train) [5/5] segment(train) [2/3] model loss is 0.03444 |
|
|
|
2023/11/29 23:15:12 - abl - INFO - loop(train) [5/5] segment(train) [3/3] model loss is 0.03274 |
|
|
|
2023/11/29 23:15:12 - abl - INFO - Evaluation start: loop(val) [5] |
|
|
|
2023/11/29 23:15:13 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/semantics_accuracy: 0.983 |
|
|
|
2023/11/29 23:15:13 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.985 mnist_add/semantics_accuracy: 0.970 |
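The ``[1/3]``, ``[2/3]``, ``[3/3]`` segment counters in the log are consistent with the training set being split into chunks of ``segment_size`` examples. A small sketch of that arithmetic follows. The training-set size of 30,000 pairs is an assumption (it is not stated here), as is rounding up for a partial final segment.

.. code:: python

    import math

    num_examples = 30000  # assumption: number of training pairs, not stated in the text
    segment_size = 10000  # value passed to bridge.train(...)

    # Assumed rule: a partial final chunk still counts as one segment.
    num_segments = math.ceil(num_examples / segment_size)
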