
[FIX] resolve some comments

pull/1/head
troyyyyy 2 years ago
parent
commit
dc0c36d8a4
24 changed files with 563 additions and 233 deletions
  1. +2
    -2
      abl/reasoning/kb.py
  2. +135
    -75
      docs/Examples/MNISTAdd.rst
  3. +2
    -2
      docs/Intro/Basics.rst
  4. +1
    -0
      docs/Intro/Bridge.rst
  5. +1
    -0
      docs/Intro/Datasets.rst
  6. +5
    -2
      docs/Intro/Evaluation.rst
  7. +1
    -0
      docs/Intro/Learning.rst
  8. +6
    -3
      docs/Intro/Quick-Start.rst
  9. +4
    -3
      docs/Intro/Reasoning.rst
  10. +1
    -1
      docs/Overview/Abductive-Learning.rst
  11. +0
    -4
      examples/hwf/datasets/README.md
  12. +3
    -0
      examples/hwf/datasets/__init__.py
  13. +61
    -0
      examples/hwf/datasets/get_dataset.py
  14. +0
    -45
      examples/hwf/datasets/get_hwf.py
  15. +70
    -0
      examples/hwf/datasets/test.ipynb
  16. +48
    -9
      examples/hwf/hwf_example.ipynb
  17. +114
    -0
      examples/hwf/main.py
  18. +1
    -0
      examples/hwf/requirements.txt
  19. +4
    -5
      examples/mnist_add/README.md
  20. +2
    -2
      examples/mnist_add/datasets/__init__.py
  21. +12
    -20
      examples/mnist_add/datasets/get_dataset.py
  22. +17
    -17
      examples/mnist_add/main.py
  23. +72
    -40
      examples/mnist_add/mnist_add.ipynb
  24. +1
    -3
      examples/mnist_add/requirements.txt

+ 2
- 2
abl/reasoning/kb.py

@@ -43,8 +43,8 @@ class KBBase(ABC):

Notes
-----
Users should inherit from this base class to build their own knowledge base. For the
user-build KB (an inherited subclass), it's only required for the user to provide the
Users should derive from this base class to build their own knowledge base. For the
user-built KB (a derived subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
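
As an illustration, a minimal user-built KB consistent with this docstring could look as follows (a sketch; MultiplyKB is a hypothetical example, not part of the package):

from abl.reasoning import KBBase

class MultiplyKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)

    # Override logic_forward to specify how (deductive) reasoning is performed:
    # here, the reasoning result is the product of the pseudo labels.
    def logic_forward(self, nums):
        result = 1
        for n in nums:
            result *= n
        return result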


+ 135
- 75
docs/Examples/MNISTAdd.rst

@@ -1,49 +1,56 @@
MNIST Addition
==============

This example shows a simple implementation of MNIST Addition, which was
first introduced in `Manhaeve et al.,
2018 <https://arxiv.org/abs/1805.10872>`__. In this task, the inputs are
pairs of MNIST handwritten images, and the outputs are their sums.

In Abductive Learning, we hope to first use learning part to map the
input images to their digits (we call it pseudo labels), and then use
reasoning part to calculate the summation of these pseudo labels to get
the final result.
In this example, we show an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`_. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. Our objective is to input a pair of handwritten images and
accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have the ground truth of the digits, the
reasoning part will leverage domain knowledge and revise the initial
digits yielded by the learning part into results derived from abductive
reasoning. This process enables us to further refine and retrain the
machine learning model.

.. code:: ipython3

# Import necessary libraries and modules
import os.path as osp
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from abl.bridge import SimpleBridge
from abl.evaluation import ReasoningMetric, SymbolMetric
from examples.mnist_add.datasets import get_dataset
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from examples.mnist_add.datasets import get_mnist_add
from examples.models.nn import LeNet5
from abl.bridge import SimpleBridge

Load Datasets
-------------
Working with Data
-----------------

First, we get training and testing data:
First, we get the training and testing datasets:

.. code:: ipython3

train_data = get_mnist_add(train=True, get_pseudo_label=True)
test_data = get_mnist_add(train=False, get_pseudo_label=True)
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

The datasets are illustrated as follows:
Both datasets contain a number of data examples. Each data example
consists of three components: X (a pair of images), gt_pseudo_label (a
pair of corresponding ground-truth digits, i.e., pseudo labels), and Y
(their sum). The datasets are illustrated as follows.

.. code:: ipython3

print(f"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set")
print(f"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.")
print("As an illustration, in the First data example of the training set, we have:")
print("As an illustration, in the first data example of the training set, we have:")
print(f"X ({len(train_data[0][0])} images):")
plt.subplot(1,2,1)
plt.axis('off')
@@ -56,36 +63,39 @@ The datasets are illustrated as follows:
print(f"Y (their sum result): {train_data[2][0]}")

Out:
.. code:: none
:class: code-out

There are 30000 data examples in the training set and 5000 data examples in the test set
Each of the data example has 3 components: X, gt_pseudo_label, and Y.
As an illustration, in the First data example of the training set, we have:
X (2 images):

.. image:: ../img/mnist_add_datasets.png
:width: 400px

.. code:: none
:class: code-out
.. code:: none
:class: code-out
There are 30000 data examples in the training set and 5000 data examples in the test set
As an illustration, in the first data example of the training set, we have:
X (2 images):
.. image:: ../img/mnist_add_datasets.png
:width: 400px

gt_pseudo_label (2 ground truth pseudo label): 7, 5
Y (their sum result): 12

.. code:: none
:class: code-out

gt_pseudo_label (2 ground truth pseudo label): 7, 5
Y (their sum result): 12

Learning Part
-------------
Building the Learning Part
--------------------------

First, we build the basic learning model. We use a simple `LeNet neural
network <https://en.wikipedia.org/wiki/LeNet>`__ to complete this task.
To build the learning part, we need to first build a base machine
learning model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__ to complete this task,
and encapsulate it within a ``BasicNN`` object to create the base model.
``BasicNN`` is a class that encapsulates a PyTorch model, transforming
it into a base model with an sklearn-style interface.

.. code:: ipython3

cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(
@@ -97,8 +107,9 @@ network <https://en.wikipedia.org/wiki/LeNet>`__ to complete this task.
num_epochs=1,
)

The base model can predict the outcome class index and the probabilities
for an image, as shown below:
``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the outcome class index and the probabilities for an
image, respectively, as shown below:

.. code:: ipython3

@@ -107,42 +118,59 @@ for an image, as shown below:
pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}")


Out:
.. code:: none
:class: code-out
.. code:: none
:class: code-out

Shape of pred_idx for a batch of 32 samples: (32,)
Shape of pred_prob for a batch of 32 samples: (32, 10)
Shape of pred_idx for a batch of 32 samples: (32,)
Shape of pred_prob for a batch of 32 samples: (32, 10)

Then, we build an instance of ``ABLModel``. The main function of
``ABLModel`` is to serialize data and provide a unified interface for
different base machine learning models.
However, the base model built above is trained to make predictions on
instance-level data (i.e., a single image), and cannot directly utilize
sample-level data (i.e., a pair of images). Therefore, we then wrap the
base model into ``ABLModel``, which enables the learning part to train,
test, and predict on sample-level data.

.. code:: ipython3

model = ABLModel(base_model)

Logic Part
----------
TODO: add an example showing the difference between the ``predict`` of ``ABLModel`` and that of the base model (a tentative sketch follows the commented code below)

.. code:: ipython3

# from abl.structures import ListData
# data_samples = ListData()
# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
# model.predict(data_samples)
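
The commented code above hints at the intended example; a tentative
sketch is given below (assumptions: ``ListData`` exposes an ``X``
attribute and ``ABLModel.predict`` accepts a ``ListData``, as the
commented code suggests). The base model predicts on a flat batch of
single images, whereas ``ABLModel`` predicts on data samples that each
hold a pair of images.

.. code:: python

    # Sketch only, not verified package behavior.
    from abl.structures import ListData

    # Instance-level: a flat batch of 6 single images for the base model.
    instance_preds = base_model.predict(X=[torch.randn(1, 28, 28) for _ in range(6)])

    # Sample-level: 3 data samples, each containing a pair of images.
    data_samples = ListData()
    data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
    sample_preds = model.predict(data_samples)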

Building the Reasoning Part
---------------------------

In the logic part, we first build a knowledge base.
In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we have to
first initialize the ``pseudo_label_list`` parameter, specifying the
list of possible pseudo labels, and then override the ``logic_forward``
function, defining how to perform (deductive) reasoning.

.. code:: ipython3

# Build knowledge base and reasoner
class AddKB(KBBase):
def __init__(self, pseudo_label_list):
def __init__(self, pseudo_label_list=list(range(10))):
super().__init__(pseudo_label_list)
# Implement the deduction function
def logic_forward(self, nums):
return sum(nums)
kb = AddKB(pseudo_label_list=list(range(10)))
kb = AddKB()

The knowledge base can perform logical reasoning. Below is an example of
performing (deductive) reasoning:
performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING

.. code:: ipython3

@@ -150,25 +178,57 @@ performing (deductive) reasoning:
reasoning_result = kb.logic_forward(pseudo_label_sample)
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.")


Out:
.. code:: none
:class: code-out
.. code:: none
:class: code-out

Reasoning result of pseudo label sample [1, 2] is 3.
Reasoning result of pseudo label sample [1, 2] is 3.

Then, we create a reasoner. It can help minimize inconsistencies between
the knowledge base and pseudo labels predicted by the learning part.
.. note::

In addition to building a knowledge base based on ``KBBase``, we
can also establish a knowledge base with a ground KB using ``GroundKB``,
or a knowledge base implemented based on Prolog files using
``PrologKB``. The corresponding code for these implementations can be
found in the ``examples/mnist_add/main.py`` file. Those interested are encouraged to
examine it for further insights.

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner can minimize inconsistencies between the knowledge base and the
pseudo labels predicted by the learning part, and then return only the
candidate with the highest consistency.

.. code:: ipython3

reasoner = Reasoner(kb, dist_func="confidence")
reasoner = Reasoner(kb)

.. note::

When creating the reasoner, the definition of “consistency” can be
customized via the ``dist_func`` parameter. In the code above, we
employ a consistency measurement based on confidence, which calculates
the consistency between a data sample and candidates based on the
confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
provide options for utilizing other forms of consistency measurement.

Evaluation Metrics
------------------
Also, during the process of inconsistency minimization, one can leverage
the `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
Options for this are also available in ``examples/mnist_add/main.py``. Those interested are
encouraged to explore these features.

Set up evaluation metrics. These metrics will be used to evaluate the
model performance during training and testing.
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

@@ -177,23 +237,23 @@ model performance during training and testing.
Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning part.
Now, the last step is to bridge the learning and reasoning part. We
proceed with this step by creating an instance of ``SimpleBridge``.

.. code:: ipython3

bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing.
Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

# Build logger
print_log("Abductive Learning on the MNIST Addition example.", logger="current")
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
bridge.test(test_data)

@@ -254,4 +314,4 @@ Out:
abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_5.pth
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.482
More concrete examples are available in ``examples/mnist_add`` folder.
More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.

+ 2
- 2
docs/Intro/Basics.rst

@@ -22,7 +22,7 @@ AI: data, models, and knowledge.
.. image:: ../img/ABL-Package.png

**Data** module manages the storage, operation, and evaluation of data.
It first features class ``ListData`` (inherited from base class
It first features class ``ListData`` (derived from base class
``BaseDataElement``), which defines the data structures used in
Abductive Learning, and comprises common data operations like insertion,
deletion, retrieval, slicing, etc. Additionally, a series of Evaluation
@@ -46,7 +46,7 @@ responsible for minimizing the inconsistency between the knowledge base
and learning models.

Finally, the integration of these three modules occurs through
**Bridge** module, which features class ``SimpleBridge`` (inherited from base
**Bridge** module, which features class ``SimpleBridge`` (derived from base
class ``BaseBridge``). Bridge module synthesizes data, learning, and
reasoning, and facilitates the training and testing of the entire
Abductive Learning framework.


+ 1
- 0
docs/Intro/Bridge.rst

@@ -14,6 +14,7 @@ In this section, we will look at how to bridge learning and reasoning parts to t

.. code:: python

# Import necessary modules
from abl.bridge import BaseBridge, SimpleBridge

``BaseBridge`` is an abstract class with the following initialization parameters:


+ 1
- 0
docs/Intro/Datasets.rst

@@ -14,6 +14,7 @@ In this section, we will look at the datasets and data structures in ABL-Package

.. code:: python

# Import necessary libraries and modules
import torch
from abl.structures import ListData



+ 5
- 2
docs/Intro/Evaluation.rst

@@ -10,19 +10,22 @@
Evaluation Metrics
==================

In this section, we will look at how to build evaluation metrics. ABL-Package seperates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing.
In this section, we will look at how to build evaluation metrics.

.. code:: python

# Import necessary modules
from abl.evaluation import BaseMetric, SymbolMetric, ReasoningMetric

ABL-Package separates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its subclasses. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing.

To customize our own metrics, we need to inherit from ``BaseMetric`` and implement the ``process`` and ``compute_metrics`` methods.

- The ``process`` method accepts a batch of model prediction and saves the information to ``self.results`` property after processing this batch.
- The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results.

Besides, we can assign a ``str`` to the ``prefix`` argument of the ``__init__`` method. This string is automatically prefixed to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``mnist_add/character_accuracy``.
We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively. Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metrics.
We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the final reasoning results, respectively. Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metric.

.. code:: python
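
# What follows is a sketch of a custom metric based on the description
# above; the data field names (pred_pseudo_label, gt_pseudo_label) are
# assumptions for illustration, not the package's verified API.
class CharAccuracy(BaseMetric):
    def __init__(self, prefix=None):
        super().__init__(prefix)

    def process(self, data_samples):
        # Save per-batch (num_correct, num_total) counts to self.results.
        for pred, gt in zip(data_samples.pred_pseudo_label, data_samples.gt_pseudo_label):
            correct = sum(p == g for p, g in zip(pred, gt))
            self.results.append((correct, len(gt)))

    def compute_metrics(self):
        # Aggregate everything saved in self.results into a metric dict.
        correct = sum(r[0] for r in self.results)
        total = sum(r[1] for r in self.results)
        return {"character_accuracy": correct / total}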



+ 1
- 0
docs/Intro/Learning.rst

@@ -14,6 +14,7 @@ In this section, we will look at how to build the learning part. In ABL-Package,

.. code:: python

# Import necessary libraries and modules
import sklearn
import torchvision
from abl.learning import BasicNN, ABLModel


+ 6
- 3
docs/Intro/Quick-Start.rst

@@ -9,7 +9,7 @@
Quick Start
===========

We use the MNIST Addition task as a quick start example. In this task, the inputs are pairs of MNIST handwritten images, and the outputs are their sums. Refer to the links in each section to dive deeper.
We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base containing information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. Refer to the links in each section to dive deeper.

Working with Data
-----------------
@@ -90,8 +90,11 @@ function specifying how to perform (deductive) reasoning.

Then, we create a reasoner by instantiating the class
``Reasoner`` and passing the knowledge base as a parameter.
The reasoner can be used to minimize inconsistencies between the
knowledge base and the prediction from the learning part.
Due to the indeterminism of abductive reasoning, there could
be multiple candidates compatible with the knowledge base.
When this happens, the reasoner can minimize inconsistencies
between the knowledge base and the pseudo labels predicted by the
learning part, and then return only the candidate with the highest consistency.

.. code:: python
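
# For example, creating the reasoner from the knowledge base built above:
reasoner = Reasoner(kb)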



+ 4
- 3
docs/Intro/Reasoning.rst

@@ -22,12 +22,13 @@ In ABL-Package, building the reasoning part involves two steps:

.. code:: python

# Import necessary modules
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner

Building a knowledge base
-------------------------

Generally, we can create a subclass inherited from ``KBBase`` to build our own
Generally, we can create a subclass derived from ``KBBase`` to build our own
knowledge base. In addition, ABL-Package also offers several predefined
subclasses of ``KBBase`` (e.g., ``PrologKB`` and ``GroundKB``),
which we can utilize to build our knowledge base more conveniently.
@@ -35,7 +36,7 @@ which we can utilize to build our knowledge base more conveniently.
Building a knowledge base from `KBBase`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For the user-built KB from `KBBase` (an inherited subclass), it's only
For the user-built KB from ``KBBase`` (a derived subclass), it's only
required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function
and override the ``logic_forward`` function:

@@ -184,7 +185,7 @@ As an example, the ``GKB_len_list`` for MNIST Addition should be ``[2]``,
since all pseudo labels in the example consist of two digits. Therefore,
the construction of KB with GKB (``add_ground_kb``) of MNIST Addition would be
as follows. As mentioned, the difference between this and the previously
built ``add_kb`` lies only in the base class from which it is inherited
built ``add_kb`` lies only in the base class from which it is derived
and whether an extra parameter ``GKB_len_list`` is passed.

.. code:: python
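
# A sketch of the GroundKB-based construction described above; the exact
# __init__ signature of GroundKB (taking GKB_len_list) is an assumption.
class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)

add_ground_kb = AddGroundKB()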


+ 1
- 1
docs/Overview/Abductive-Learning.rst

@@ -51,7 +51,7 @@ The following figure illustrates this process:

We can observe that in the above figure, the left half involves machine
learning, while the right half involves logical reasoning. Thus, the
entire abductive learning process is a continuous cycle of machine
entire Abductive Learning process is a continuous cycle of machine
learning and logical reasoning. This effectively forms a paradigm that
is dual-driven by both data and domain knowledge, integrating and
balancing the use of machine learning and logical reasoning in a unified


+ 0
- 4
examples/hwf/datasets/README.md

@@ -1,4 +0,0 @@
Download the Handwritten Formula Recognition dataset from [google drive](https://drive.google.com/file/d/1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy/view?usp=sharing) to this folder and unzip it:
```
unzip HWF.zip
```

+ 3
- 0
examples/hwf/datasets/__init__.py

@@ -0,0 +1,3 @@
from .get_dataset import get_dataset

__all__ = ["get_dataset"]

+ 61
- 0
examples/hwf/datasets/get_dataset.py

@@ -0,0 +1,61 @@
import json
import os
import gdown
import zipfile

from PIL import Image
from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])

def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
zip_ref.extractall()
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
os.remove(zip_file_name)
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in './datasets' folder")

def get_dataset(train=True, get_pseudo_label=False):
data_dir = CURRENT_DIR + '/data'
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download'

if not os.path.exists(data_dir):
print("Dataset not exist, downloading it...")
download_and_unzip(url, 'HWF.zip')
print("Download and extraction complete.")
if train:
file = os.path.join(data_dir, "expr_train.json")
else:
file = os.path.join(data_dir, "expr_test.json")

X = []
Z = [] if get_pseudo_label else None
Y = []
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
with open(file) as f:
data = json.load(f)
for idx in range(len(data)):
imgs = []
if get_pseudo_label:
imgs_pseudo_label = []
for img_path in data[idx]["img_paths"]:
img = Image.open(img_dir + img_path).convert("L")
img = img_transform(img)
imgs.append(img)
if get_pseudo_label:
imgs_pseudo_label.append(img_path.split("/")[0])
X.append(imgs)
if get_pseudo_label:
Z.append(imgs_pseudo_label)
Y.append(data[idx]["res"])

return X, Z, Y


+ 0
- 45
examples/hwf/datasets/get_hwf.py

@@ -1,45 +0,0 @@
import json
import os

from PIL import Image
from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


def get_data(file, get_pseudo_label):
X, Y = [], []
if get_pseudo_label:
Z = []
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
with open(file) as f:
data = json.load(f)
for idx in range(len(data)):
imgs = []
imgs_pseudo_label = []
for img_path in data[idx]["img_paths"]:
img = Image.open(img_dir + img_path).convert("L")
img = img_transform(img)
imgs.append(img)
if get_pseudo_label:
imgs_pseudo_label.append(img_path.split("/")[0])
X.append(imgs)
if get_pseudo_label:
Z.append(imgs_pseudo_label)
Y.append(data[idx]["res"])

if get_pseudo_label:
return X, Z, Y
else:
return X, None, Y


def get_hwf(train=True, get_gt_pseudo_label=False):
if train:
file = os.path.join(CURRENT_DIR, "data/expr_train.json")
else:
file = os.path.join(CURRENT_DIR, "data/expr_test.json")

return get_data(file, get_gt_pseudo_label)

+ 70
- 0
examples/hwf/datasets/test.ipynb

@@ -0,0 +1,70 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name '__file__' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 10>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchvision\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtransforms\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m transforms\n\u001b[0;32m---> 10\u001b[0m CURRENT_DIR \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mabspath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(\u001b[38;5;18;43m__file__\u001b[39;49m))\n\u001b[1;32m 12\u001b[0m img_transform \u001b[38;5;241m=\u001b[39m transforms\u001b[38;5;241m.\u001b[39mCompose([transforms\u001b[38;5;241m.\u001b[39mToTensor(), transforms\u001b[38;5;241m.\u001b[39mNormalize((\u001b[38;5;241m0.5\u001b[39m,), (\u001b[38;5;241m1\u001b[39m,))])\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mrequests\u001b[39;00m\n",
"\u001b[0;31mNameError\u001b[0m: name '__file__' is not defined"
]
}
],
"source": [
"import json\n",
"import os\n",
"import requests\n",
"import os\n",
"import zipfile\n",
"\n",
"from PIL import Image\n",
"from torchvision.transforms import transforms\n",
"\n",
"CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))\n",
"\n",
"img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])\n",
"\n",
"import requests\n",
"import os\n",
"import zipfile\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "abl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

+ 48
- 9
examples/hwf/hwf_example.ipynb

@@ -1,8 +1,26 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Handwritten Formula (HWF)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example shows a simple implementation of Handwritten Formula, which was first introduced in [Li et al., 2020](https://arxiv.org/abs/2006.06649). In this task, the inputs are images of decimal formulas, and the outputs are their computed results.\n",
"\n",
"In Abductive Learning, we hope to first use learning part to map the input images to their symbols (we call them pseudo labels), and then use reasoning part to calculate the summation of these pseudo labels to get the final result.\n",
"\n",
"The HWF dataset ontains images of decimal formulas and their computed results. "
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -18,14 +36,22 @@
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from examples.models.nn import SymbolNet\n",
"from datasets.get_hwf import get_hwf"
"from examples.hwf.datasets.get_dataset import get_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/18 12:48:19 - abl - INFO - Abductive Learning on the HWF example.\n"
]
}
],
"source": [
"# Initialize logger and print basic information\n",
"print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n",
@@ -159,13 +185,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] 没有那个文件或目录: '/home/huwc/ABL-Package/examples/hwf/datasets/data/expr_train.json'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [4]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Get training and testing data\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m train_data \u001b[38;5;241m=\u001b[39m \u001b[43mget_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_pseudo_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m test_data \u001b[38;5;241m=\u001b[39m get_dataset(train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, get_pseudo_label\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/ABL-Package/examples/hwf/datasets/get_dataset.py:21\u001b[0m, in \u001b[0;36mget_dataset\u001b[0;34m(train, get_pseudo_label)\u001b[0m\n\u001b[1;32m 19\u001b[0m Y \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 20\u001b[0m img_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(CURRENT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/Handwritten_Math_Symbols/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 22\u001b[0m data \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mload(f)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(data)):\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] 没有那个文件或目录: '/home/huwc/ABL-Package/examples/hwf/datasets/data/expr_train.json'"
]
}
],
"source": [
"# Get training and testing data\n",
"train_data = get_hwf(train=True, get_gt_pseudo_label=True)\n",
"test_data = get_hwf(train=False, get_gt_pseudo_label=True)"
"train_data = get_dataset(train=True, get_pseudo_label=True)\n",
"test_data = get_dataset(train=False, get_pseudo_label=True)"
]
},
{
@@ -220,7 +259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {


+ 114
- 0
examples/hwf/main.py

@@ -0,0 +1,114 @@
# %%
import torch
import numpy as np
import torch.nn as nn
import os.path as osp

from abl.reasoning import Reasoner, KBBase
from abl.learning import BasicNN, ABLModel
from abl.bridge import SimpleBridge
from abl.evaluation import SymbolMetric, ReasoningMetric
from abl.utils import ABLLogger, print_log

from examples.models.nn import SymbolNet
from examples.hwf.datasets import get_dataset

# %%
# Initialize logger and print basic information
print_log("Abductive Learning on the HWF example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")

# %% [markdown]
# ### Logic Part

# %%
# Initialize knowledge base and reasoner
class HWF_KB(KBBase):
def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))


kb = HWF_KB(
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"],
max_err=1e-10,
use_cache=False,
)
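# Quick sanity check of the deduction defined above: "times" maps to "*",
# so ["3", "times", "5"] evaluates to 15, while an invalid candidate
# (e.g., one of even length) yields np.inf.
assert kb.logic_forward(["3", "times", "5"]) == 15
assert kb.logic_forward(["3", "times"]) == np.inf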
reasoner = Reasoner(kb, dist_func="confidence")

# %% [markdown]
# ### Machine Learning Part

# %%
# Initialize necessary component for machine learning part
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

# %%
# Initialize BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
model=cls,
loss_fn=loss_fn,
optimizer=optimizer,
device=device,
save_interval=1,
save_dir=weights_dir,
batch_size=128,
num_epochs=3,
)

# %%
# Initialize ABL model
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)

# %% [markdown]
# ### Metric

# %%
# Add metric
metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

# %% [markdown]
# ### Dataset

# %%
# Get training and testing data
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

# %% [markdown]
# ### Bridge Machine Learning and Logic Reasoning

# %%
bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)

# %% [markdown]
# ### Train and Test

# %%
bridge.train(train_data, train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir)
bridge.test(test_data)



+ 1
- 0
examples/hwf/requirements.txt

@@ -0,0 +1 @@
gdown

+ 4
- 5
examples/mnist_add/README.md

@@ -4,16 +4,15 @@ This example shows a simple implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872) tas

## Run

```
bash
```bash
pip install -r requirements.txt
python main.py
```

## Usage

```
usage: test.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
```bash
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
[--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]
[--loops LOOPS] [--segment_size SEGMENT_SIZE]
[--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
@@ -34,7 +33,7 @@ optional arguments:
batch size (default : 32)
--loops LOOPS number of loop iterations (default : 5)
--segment_size SEGMENT_SIZE
number of loop iterations (default : 1/3)
segment size (default : 1/3)
--save_interval SAVE_INTERVAL
save interval (default : 1)
--max-revision MAX_REVISION


+ 2
- 2
examples/mnist_add/datasets/__init__.py

@@ -1,3 +1,3 @@
from .get_mnist_add import get_mnist_add
from .get_dataset import get_dataset

__all__ = ["get_mnist_add"]
__all__ = ["get_dataset"]

examples/mnist_add/datasets/get_mnist_add.py → examples/mnist_add/datasets/get_dataset.py

@@ -5,25 +5,7 @@ from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

def get_data(file, img_dataset, get_pseudo_label):
X = []
if get_pseudo_label:
Z = []
Y = []
with open(file) as f:
for line in f:
line = line.strip().split(" ")
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]])
if get_pseudo_label:
Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]])
Y.append(int(line[2]))

if get_pseudo_label:
return X, Z, Y
else:
return X, None, Y

def get_mnist_add(train=True, get_pseudo_label=False):
def get_dataset(train=True, get_pseudo_label=False):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
@@ -35,4 +17,14 @@ def get_mnist_add(train=True, get_pseudo_label=False):
else:
file = os.path.join(CURRENT_DIR, "test_data.txt")

return get_data(file, img_dataset, get_pseudo_label)
X = []
Z = [] if get_pseudo_label else None
Y = []
with open(file) as f:
for line in f:
x1, x2, y = map(int, line.strip().split(" "))
X.append([img_dataset[x1][0], img_dataset[x2][0]])
if get_pseudo_label:
Z.append([img_dataset[x1][1], img_dataset[x2][1]])
Y.append(y)
return X, Z, Y

+ 17
- 17
examples/mnist_add/main.py

@@ -5,13 +5,13 @@ import argparse
import torch
from torch import nn

from abl.bridge import SimpleBridge
from abl.evaluation import ReasoningMetric, SymbolMetric
from examples.mnist_add.datasets import get_dataset
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.utils import ABLLogger, print_log
from examples.mnist_add.datasets import get_mnist_add
from examples.models.nn import LeNet5
from abl.bridge import SimpleBridge

class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10))):
@@ -42,7 +42,7 @@ def main():
parser.add_argument('--loops', type=int, default=5,
help='number of loop iterations (default : 5)')
parser.add_argument('--segment_size', type=float, default=1/3,
help='number of loop iterations (default : 1/3)')
help='segment size (default : 1/3)')
parser.add_argument('--save_interval', type=int, default=1,
help='save interval (default : 1)')
parser.add_argument('--max-revision', type=int or float, default=-1,
@@ -56,8 +56,12 @@ def main():
help='use GroundKB (default: False)')

args = parser.parse_args()
### Working with Data
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Learning Part
### Building the Learning Part
# Build necessary components for BasicNN
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
@@ -66,7 +70,6 @@ def main():
device = torch.device("cuda" if use_cuda else "cpu")

# Build BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
cls,
loss_fn,
@@ -77,27 +80,24 @@ def main():
)

# Build ABLModel
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)
### Building the Reasoning Part
# Build knowledge base
if args.prolog:
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
elif args.ground:
kb = AddGroundKB()
else:
kb = AddKB()
reasoner = Reasoner(kb, dist_func="confidence", max_revision=args.max_revision, require_more_revision=args.require_more_revision)

### Evaluation Metrics
# Get training and testing data
train_data = get_mnist_add(train=True, get_pseudo_label=True)
test_data = get_mnist_add(train=False, get_pseudo_label=True)
# Create reasoner
reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision)

# Set up metrics
### Building Evaluation Metrics
metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

### Bridge Machine Learning and Logic Reasoning
### Bridge Learning and Reasoning
bridge = SimpleBridge(model, reasoner, metric_list)

# Build logger


examples/mnist_add/mnist_add_example.ipynb → examples/mnist_add/mnist_add.ipynb

@@ -6,9 +6,9 @@
"source": [
"# MNIST Addition\n",
"\n",
"This example shows a simple implementation of MNIST Addition, which was first introduced in [Manhaeve et al., 2018](https://arxiv.org/abs/1805.10872). In this task, the inputs are pairs of MNIST handwritten images, and the outputs are their sums.\n",
"This notebook shows an implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872). In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base which contain information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum.\n",
"\n",
"In Abductive Learning, we hope to first use learning part to map the input images to their digits (we call it pseudo labels), and then use reasoning part to calculate the summation of these pseudo labels to get the final result."
"Intuitively, we first use a machine learning model (learning part) to convert the input images to digits (we call them pseudo labels), and then use the knowledge base (reasoning part) to calculate the sum of these digits. Since we do not have ground-truth of the digits, the reasoning part will leverage domain knowledge and revise the initial digits yielded by the learning part into results derived from abductive reasoning. This process enables us to further refine and retrain the machine learning model."
]
},
{
@@ -17,28 +17,27 @@
"metadata": {},
"outputs": [],
"source": [
"# Import necessary libraries and modules\n",
"import os.path as osp\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from examples.mnist_add.datasets import get_dataset\n",
"from examples.models.nn import LeNet5\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.mnist_add.datasets import get_mnist_add\n",
"from examples.models.nn import LeNet5"
"from abl.bridge import SimpleBridge"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Datasets\n",
"## Working with Data\n",
"\n",
"First, we get training and testing data:"
"First, we get the training and testing datasets:"
]
},
{
@@ -47,15 +46,15 @@
"metadata": {},
"outputs": [],
"source": [
"train_data = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_data = get_mnist_add(train=False, get_pseudo_label=True)"
"train_data = get_dataset(train=True, get_pseudo_label=True)\n",
"test_data = get_dataset(train=False, get_pseudo_label=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The datasets are illustrated as follows:"
"Both datasets contain several data examples. In each data example, we have three components: X (a pair of images), gt_pseudo_label (a pair of corresponding ground truth digits, i.e., pseudo labels), and Y (their sum). The datasets are illustrated as follows. "
]
},
{
@@ -68,8 +67,7 @@
"output_type": "stream",
"text": [
"There are 30000 data examples in the training set and 5000 data examples in the test set\n",
"Each of the data example has 3 components: X, gt_pseudo_label, and Y.\n",
"As an illustration, in the First data example of the training set, we have:\n",
"As an illustration, in the first data example of the training set, we have:\n",
"X (2 images):\n"
]
},
@@ -94,8 +92,7 @@
],
"source": [
"print(f\"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set\")\n",
"print(f\"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.\")\n",
"print(\"As an illustration, in the First data example of the training set, we have:\")\n",
"print(\"As an illustration, in the first data example of the training set, we have:\")\n",
"print(f\"X ({len(train_data[0][0])} images):\")\n",
"plt.subplot(1,2,1)\n",
"plt.axis('off') \n",
@@ -113,14 +110,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Part"
"## Building the Learning Part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we build the basic learning model. We use a simple [LeNet neural network](https://en.wikipedia.org/wiki/LeNet) to complete this task."
"To build the learning part, we need to first build a base machine learning model. We use a simple [LeNet-5 neural network](https://en.wikipedia.org/wiki/LeNet) to complete this task, and encapsulate it within a `BasicNN` object to create the base model. `BasicNN` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. "
]
},
{
@@ -131,7 +128,7 @@
"source": [
"cls = LeNet5(num_classes=10)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"base_model = BasicNN(\n",
@@ -148,7 +145,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The base model can predict the outcome class index and the probabilities for an image, as shown below:"
"`BasicNN` offers methods like `predict` and `predict_prob`, which are used to predict the outcome class index and the probabilities for an image, respectively. As shown below:"
]
},
{
@@ -176,7 +173,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we build an instance of `ABLModel`. The main function of `ABLModel` is to serialize data and provide a unified interface for different base machine learning models."
"However, base model built above are trained to make predictions on instance-level data, i.e., a single image, and can not directly utilize sample-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel` which enables the learning part to train, test, and predict on sample-level data."
]
},
{
@@ -192,44 +189,63 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logic Part"
"TODO: 示例展示ablmodel和base model的predict的不同"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# from abl.structures import ListData\n",
"# data_samples = ListData()\n",
"# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n",
"\n",
"# model.predict(data_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the Reasoning Part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the logic part, we first build a knowledge base."
"In the reasoning part, we first build a knowledge base which contain information on how to perform addition operations. We build it by creating a subclass of `KBBase`. In the derived subclass, we have to first initialize the `pseudo_label_list` parameter specifying list of possible pseudo labels, and then override the `logic_forward` function defining how to perform (deductive) reasoning."
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Build knowledge base and reasoner\n",
"class AddKB(KBBase):\n",
" def __init__(self, pseudo_label_list):\n",
" def __init__(self, pseudo_label_list=list(range(10))):\n",
" super().__init__(pseudo_label_list)\n",
"\n",
" # Implement the deduction function\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"kb = AddKB(pseudo_label_list=list(range(10)))"
"kb = AddKB()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The knowledge base can perform logical reasoning. Below is an example of performing (deductive) reasoning:"
"The knowledge base can perform logical reasoning. Below is an example of performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -250,16 +266,32 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we create a reasoner. It can help minimize inconsistencies between the knowledge base and pseudo labels predicted by the learning part."
"Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`, or a knowledge base implemented based on Prolog files using `PrologKB`. The corresponding code for these implementations can be found in the `main.py` file. Those interested are encouraged to examine it for further insights."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we create a reasoner by instantiating the class ``Reasoner``. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible to the knowledge base. When this happens, reasoner can minimize inconsistencies between the knowledge base and pseudo labels predicted by the learning part, and then return only one candidate which has highest consistency."
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"reasoner = Reasoner(kb, dist_func=\"confidence\")"
"reasoner = Reasoner(kb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data sample and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n",
"\n",
"Also, during process of inconsistency minimization, one can leverage [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features."
]
},
{
@@ -267,19 +299,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation Metrics"
"## Building Evaluation Metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing."
"Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolMetric` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively."
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -293,12 +325,12 @@
"source": [
"## Bridge Learning and Reasoning\n",
"\n",
"Now, the last step is to bridge the learning and reasoning part."
"Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -309,7 +341,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Perform training and testing."
"Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`."
]
},
{
@@ -344,7 +376,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {

+ 1
- 3
examples/mnist_add/requirements.txt

@@ -1,4 +1,2 @@
torch
torchvision
torchaudio
abl
matplotlib
