diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index cbde1ca..26068d9 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -43,8 +43,8 @@ class KBBase(ABC): Notes ----- - Users should inherit from this base class to build their own knowledge base. For the - user-build KB (an inherited subclass), it's only required for the user to provide the + Users should derive from this base class to build their own knowledge base. For the + user-built KB (a derived subclass), it's only required for the user to provide the `pseudo_label_list` and override the `logic_forward` function (specifying how to perform logical reasoning). After that, other operations (e.g. how to perform abductive reasoning) will be automatically set up. diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index 9c0696b..1ca46d4 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -1,49 +1,56 @@ MNIST Addition ============== -This example shows a simple implementation of MNIST Addition, which was -first introduced in `Manhaeve et al., -2018 `__. In this task, the inputs are -pairs of MNIST handwritten images, and the outputs are their sums. - -In Abductive Learning, we hope to first use learning part to map the -input images to their digits (we call it pseudo labels), and then use -reasoning part to calculate the summation of these pseudo labels to get -the final result. +In this example, we show an implementation of `MNIST +Addition `_. In this task, pairs of +MNIST handwritten images and their sums are given, along with a domain +knowledge base which contains information on how to perform addition +operations. Our objective is to input a pair of handwritten images and +accurately determine their sum. + +Intuitively, we first use a machine learning model (learning part) to +convert the input images to digits (we call them pseudo labels), and +then use the knowledge base (reasoning part) to calculate the sum of +these digits.
Since we do not have ground-truth of the digits, the +reasoning part will leverage domain knowledge and revise the initial +digits yielded by the learning part into results derived from abductive +reasoning. This process enables us to further refine and retrain the +machine learning model. .. code:: ipython3 + # Import necessary libraries and modules import os.path as osp - import torch import torch.nn as nn import matplotlib.pyplot as plt - - from abl.bridge import SimpleBridge - from abl.evaluation import ReasoningMetric, SymbolMetric + from examples.mnist_add.datasets import get_dataset + from examples.models.nn import LeNet5 from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, Reasoner + from abl.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log - from examples.mnist_add.datasets import get_mnist_add - from examples.models.nn import LeNet5 + from abl.bridge import SimpleBridge -Load Datasets -------------- +Working with Data +----------------- -First, we get training and testing data: +First, we get the training and testing datasets: .. code:: ipython3 - train_data = get_mnist_add(train=True, get_pseudo_label=True) - test_data = get_mnist_add(train=False, get_pseudo_label=True) + train_data = get_dataset(train=True, get_pseudo_label=True) + test_data = get_dataset(train=False, get_pseudo_label=True) -The datasets are illustrated as follows: +Both datasets contain several data examples. In each data example, we +have three components: X (a pair of images), gt_pseudo_label (a pair of +corresponding ground truth digits, i.e., pseudo labels), and Y (their sum). The datasets are illustrated +as follows. .. 
code:: ipython3 print(f"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set") - print(f"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.") - print("As an illustration, in the First data example of the training set, we have:") + print("As an illustration, in the first data example of the training set, we have:") print(f"X ({len(train_data[0][0])} images):") plt.subplot(1,2,1) plt.axis('off') @@ -56,36 +63,39 @@ The datasets are illustrated as follows: print(f"Y (their sum result): {train_data[2][0]}") Out: - .. code:: none - :class: code-out - - There are 30000 data examples in the training set and 5000 data examples in the test set - Each of the data example has 3 components: X, gt_pseudo_label, and Y. - As an illustration, in the First data example of the training set, we have: - X (2 images): - - .. image:: ../img/mnist_add_datasets.png - :width: 400px - - .. code:: none - :class: code-out + .. code:: none + :class: code-out + + There are 30000 data examples in the training set and 5000 data examples in the test set + As an illustration, in the first data example of the training set, we have: + X (2 images): + + .. image:: ../img/mnist_add_datasets.png + :width: 400px - gt_pseudo_label (2 ground truth pseudo label): 7, 5 - Y (their sum result): 12 + .. code:: none + :class: code-out + gt_pseudo_label (2 ground truth pseudo label): 7, 5 + Y (their sum result): 12 + -Learning Part -------------- +Building the Learning Part +-------------------------- -First, we build the basic learning model. We use a simple `LeNet neural -network `__ to complete this task. +To build the learning part, we need to first build a base machine +learning model. We use a simple `LeNet-5 neural +network `__ to complete this task, +and encapsulate it within a ``BasicNN`` object to create the base model. 
+``BasicNN`` is a class that encapsulates a PyTorch model, transforming +it into a base model with an sklearn-style interface. .. code:: ipython3 cls = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) + optimizer = torch.optim.Adam(cls.parameters(), lr=0.001) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") base_model = BasicNN( @@ -97,8 +107,9 @@ network `__ to complete this task. num_epochs=1, ) -The base model can predict the outcome class index and the probabilities -for an image, as shown below: +``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which +are used to predict the outcome class index and the probabilities for an +image, respectively, as shown below: .. code:: ipython3 @@ -107,42 +118,59 @@ for an image, as shown below: pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)]) print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}") + Out: - .. code:: none - :class: code-out + .. code:: none + :class: code-out - Shape of pred_idx for a batch of 32 samples: (32,) - Shape of pred_prob for a batch of 32 samples: (32, 10) + Shape of pred_idx for a batch of 32 samples: (32,) + Shape of pred_prob for a batch of 32 samples: (32, 10) -Then, we build an instance of ``ABLModel``. The main function of -``ABLModel`` is to serialize data and provide a unified interface for -different base machine learning models. +However, the base model built above is trained to make predictions on +instance-level data, i.e., a single image, and cannot directly utilize +sample-level data, i.e., a pair of images. Therefore, we then wrap the +base model into ``ABLModel``, which enables the learning part to train, +test, and predict on sample-level data. .. code:: ipython3 model = ABLModel(base_model) -Logic Part ----------- +TODO: add an example showing how ``predict`` differs between ``ABLModel`` and the base model + +..
code:: ipython3 + + # from abl.structures import ListData + # data_samples = ListData() + # data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] + + # model.predict(data_samples) + +Building the Reasoning Part +--------------------------- -In the logic part, we first build a knowledge base. +In the reasoning part, we first build a knowledge base which contains +information on how to perform addition operations. We build it by +creating a subclass of ``KBBase``. In the derived subclass, we have to +first initialize the ``pseudo_label_list`` parameter specifying the list of +possible pseudo labels, and then override the ``logic_forward`` function +defining how to perform (deductive) reasoning. .. code:: ipython3 - # Build knowledge base and reasoner class AddKB(KBBase): - def __init__(self, pseudo_label_list): + def __init__(self, pseudo_label_list=list(range(10))): super().__init__(pseudo_label_list) # Implement the deduction function def logic_forward(self, nums): return sum(nums) - kb = AddKB(pseudo_label_list=list(range(10))) + kb = AddKB() The knowledge base can perform logical reasoning. Below is an example of -performing (deductive) reasoning: +performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING .. code:: ipython3 @@ -150,25 +178,57 @@ performing (deductive) reasoning: reasoning_result = kb.logic_forward(pseudo_label_sample) print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.") + Out: - .. code:: none - :class: code-out + .. code:: none + :class: code-out - Reasoning result of pseudo label sample [1, 2] is 3. + Reasoning result of pseudo label sample [1, 2] is 3. -Then, we create a reasoner. It can help minimize inconsistencies between -the knowledge base and pseudo labels predicted by the learning part. +..
note:: + + In addition to building a knowledge base based on ``KBBase``, we + can also establish a knowledge base with a ground KB using ``GroundKB``, + or a knowledge base implemented based on Prolog files using + ``PrologKB``. The corresponding code for these implementations can be + found in the ``examples/mnist_add/main.py`` file. Those interested are encouraged to + examine it for further insights. + +Then, we create a reasoner by instantiating the class ``Reasoner``. Due +to the indeterminism of abductive reasoning, there could be multiple +candidates compatible with the knowledge base. When this happens, the reasoner +can minimize inconsistencies between the knowledge base and pseudo +labels predicted by the learning part, and then return only one +candidate which has the highest consistency. .. code:: ipython3 - reasoner = Reasoner(kb, dist_func="confidence") + reasoner = Reasoner(kb) + +.. note:: + + When creating the reasoner, the definition of “consistency” can be + customized within the ``dist_func`` parameter. In the code above, we + employ a consistency measurement based on confidence, which calculates + the consistency between the data sample and candidates based on the + confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we + provide options for utilizing other forms of consistency measurement. -Evaluation Metrics ------------------- + Also, during the process of inconsistency minimization, one can leverage + `ZOOpt library `__ for acceleration. + Options for this are also available in ``examples/mnist_add/main.py``. Those interested are + encouraged to explore these features. -Set up evaluation metrics. These metrics will be used to evaluate the -model performance during training and testing. +Building Evaluation Metrics +--------------------------- + +Next, we set up evaluation metrics. These metrics will be used to +evaluate the model performance during training and testing.
+Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are +used to evaluate the accuracy of the machine learning model’s +predictions and the accuracy of the final reasoning results, +respectively. .. code:: ipython3 @@ -177,23 +237,23 @@ model performance during training and testing. Bridge Learning and Reasoning ----------------------------- -Now, the last step is to bridge the learning and reasoning part. +Now, the last step is to bridge the learning and reasoning part. We +proceed this step by creating an instance of ``SimpleBridge``. .. code:: ipython3 bridge = SimpleBridge(model, reasoner, metric_list) -Perform training and testing. +Perform training and testing by invoking the ``train`` and ``test`` +methods of ``SimpleBridge``. .. code:: ipython3 # Build logger print_log("Abductive Learning on the MNIST Addition example.", logger="current") - - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - + bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir) bridge.test(test_data) @@ -254,4 +314,4 @@ Out: abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_5.pth abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.482 -More concrete examples are available in ``examples/mnist_add`` folder. \ No newline at end of file +More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``. \ No newline at end of file diff --git a/docs/Intro/Basics.rst b/docs/Intro/Basics.rst index b68edd8..f8b81aa 100644 --- a/docs/Intro/Basics.rst +++ b/docs/Intro/Basics.rst @@ -22,7 +22,7 @@ AI: data, models, and knowledge. .. image:: ../img/ABL-Package.png **Data** module manages the storage, operation, and evaluation of data. 
-It first features class ``ListData`` (inherited from base class +It first features class ``ListData`` (derived from base class ``BaseDataElement``), which defines the data structures used in Abductive Learning, and comprises common data operations like insertion, deletion, retrieval, slicing, etc. Additionally, a series of Evaluation @@ -46,7 +46,7 @@ responsible for minimizing the inconsistency between the knowledge base and learning models. Finally, the integration of these three modules occurs through -**Bridge** module, which features class ``SimpleBridge`` (inherited from base +**Bridge** module, which features class ``SimpleBridge`` (derived from base class ``BaseBridge``). Bridge module synthesize data, learning, and reasoning, and facilitate the training and testing of the entire Abductive Learning framework. diff --git a/docs/Intro/Bridge.rst b/docs/Intro/Bridge.rst index 0da1962..caee039 100644 --- a/docs/Intro/Bridge.rst +++ b/docs/Intro/Bridge.rst @@ -14,6 +14,7 @@ In this section, we will look at how to bridge learning and reasoning parts to t .. code:: python + # Import necessary modules from abl.bridge import BaseBridge, SimpleBridge ``BaseBridge`` is an abstract class with the following initialization parameters: diff --git a/docs/Intro/Datasets.rst b/docs/Intro/Datasets.rst index 8952559..ae3b701 100644 --- a/docs/Intro/Datasets.rst +++ b/docs/Intro/Datasets.rst @@ -14,6 +14,7 @@ In this section, we will look at the datasets and data structures in ABL-Package .. code:: python + # Import necessary libraries and modules import torch from abl.structures import ListData diff --git a/docs/Intro/Evaluation.rst b/docs/Intro/Evaluation.rst index 9c835c1..0985487 100644 --- a/docs/Intro/Evaluation.rst +++ b/docs/Intro/Evaluation.rst @@ -10,19 +10,22 @@ Evaluation Metrics ================== -In this section, we will look at how to build evaluation metrics. 
ABL-Package seperates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing. +In this section, we will look at how to build evaluation metrics. .. code:: python + # Import necessary modules from abl.evaluation import BaseMetric, SymbolMetric, ReasoningMetric +ABL-Package separates the evaluation process from model training and testing as an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing. + To customize our own metrics, we need to inherit from ``BaseMetric`` and implement the ``process`` and ``compute_metrics`` methods. - The ``process`` method accepts a batch of model prediction and saves the information to ``self.results`` property after processing this batch. - The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results. Besides, we can assign a ``str`` to the ``prefix`` argument of the ``__init__`` method. This string is automatically prefixed to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``mnist_add/character_accuracy``. -We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively.
Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metrics. +We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the final reasoning results, respectively. Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metric. .. code:: python diff --git a/docs/Intro/Learning.rst b/docs/Intro/Learning.rst index bc248e0..3111d3d 100644 --- a/docs/Intro/Learning.rst +++ b/docs/Intro/Learning.rst @@ -14,6 +14,7 @@ In this section, we will look at how to build the learning part. In ABL-Package, .. code:: python + # Import necessary libraries and modules import sklearn import torchvision from abl.learning import BasicNN, ABLModel diff --git a/docs/Intro/Quick-Start.rst b/docs/Intro/Quick-Start.rst index e44dfde..49d674e 100644 --- a/docs/Intro/Quick-Start.rst +++ b/docs/Intro/Quick-Start.rst @@ -9,7 +9,7 @@ Quick Start =========== -We use the MNIST Addition task as a quick start example. In this task, the inputs are pairs of MNIST handwritten images, and the outputs are their sums. Refer to the links in each section to dive deeper. +We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. Refer to the links in each section to dive deeper. Working with Data ----------------- @@ -90,8 +90,11 @@ function specifying how to perform (deductive) reasoning. Then, we create a reasoner by instantiating the class ``Reasoner`` and passing the knowledge base as a parameter. -The reasoner can be used to minimize inconsistencies between -the knowledge base and the prediction from the learning part.
+Due to the indeterminism of abductive reasoning, there could +be multiple candidates compatible with the knowledge base. +When this happens, the reasoner can minimize inconsistencies between +the knowledge base and pseudo labels predicted by the learning part, +and then return only one candidate which has the highest consistency. .. code:: python diff --git a/docs/Intro/Reasoning.rst b/docs/Intro/Reasoning.rst index 0314843..73f7e0d 100644 --- a/docs/Intro/Reasoning.rst +++ b/docs/Intro/Reasoning.rst @@ -22,12 +22,13 @@ In ABL-Package, building the reasoning part involves two steps: .. code:: python + # Import necessary modules from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner Building a knowledge base ------------------------- -Generally, we can create a subclass inherited from ``KBBase`` to build our own +Generally, we can create a subclass derived from ``KBBase`` to build our own knowledge base. In addition, ABL-Package also offers several predefined subclasses of ``KBBase`` (e.g., ``PrologKB`` and ``GroundKB``), which we can utilize to build our knowledge base more conveniently. @@ -35,7 +36,7 @@ which we can utilize to build our knowledge base more conveniently. Building a knowledge base from `KBBase` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For the user-built KB from `KBBase` (an inherited subclass), it's only +For the user-built KB from ``KBBase`` (a derived subclass), it's only required to pass the ``pseudo_label_list`` parameter in the ``__init__`` function and override the ``logic_forward`` function: @@ -184,7 +185,7 @@ As an example, the ``GKB_len_list`` for MNIST Addition should be ``[2]``, since all pseudo labels in the example consist of two digits. Therefore, the construction of KB with GKB (``add_ground_kb``) of MNIST Addition would be as follows.
As mentioned, the difference between this and the previously -built ``add_kb`` lies only in the base class from which it is inherited +built ``add_kb`` lies only in the base class from which it is derived and whether an extra parameter ``GKB_len_list`` is passed. .. code:: python diff --git a/docs/Overview/Abductive-Learning.rst b/docs/Overview/Abductive-Learning.rst index b0d3d5e..a4e90d1 100644 --- a/docs/Overview/Abductive-Learning.rst +++ b/docs/Overview/Abductive-Learning.rst @@ -51,7 +51,7 @@ The following figure illustrates this process: We can observe that in the above figure, the left half involves machine learning, while the right half involves logical reasoning. Thus, the -entire abductive learning process is a continuous cycle of machine +entire Abductive Learning process is a continuous cycle of machine learning and logical reasoning. This effectively forms a paradigm that is dual-driven by both data and domain knowledge, integrating and balancing the use of machine learning and logical reasoning in a unified diff --git a/examples/hwf/datasets/README.md b/examples/hwf/datasets/README.md deleted file mode 100644 index b8f3047..0000000 --- a/examples/hwf/datasets/README.md +++ /dev/null @@ -1,4 +0,0 @@ -Download the Handwritten Formula Recognition dataset from [google drive](https://drive.google.com/file/d/1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy/view?usp=sharing) to this folder and unzip it: -``` -unzip HWF.zip -``` diff --git a/examples/hwf/datasets/__init__.py b/examples/hwf/datasets/__init__.py new file mode 100644 index 0000000..a2e06bd --- /dev/null +++ b/examples/hwf/datasets/__init__.py @@ -0,0 +1,3 @@ +from .get_dataset import get_dataset + +__all__ = ["get_dataset"] \ No newline at end of file diff --git a/examples/hwf/datasets/get_dataset.py b/examples/hwf/datasets/get_dataset.py new file mode 100644 index 0000000..3c0029b --- /dev/null +++ b/examples/hwf/datasets/get_dataset.py @@ -0,0 +1,61 @@ +import json +import os +import gdown +import zipfile + 
from PIL import Image
from torchvision.transforms import transforms

# Directory containing this module; the dataset is read from its ``data`` subfolder.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


def download_and_unzip(url, zip_file_name, extract_dir=None):
    """Download a zip archive with ``gdown`` and extract it.

    Parameters
    ----------
    url : str
        Address the archive is downloaded from.
    zip_file_name : str
        Path the archive is saved to. The file is removed afterwards, whether
        or not the download and extraction succeeded.
    extract_dir : str, optional
        Directory the archive is extracted into. ``None`` (the default) keeps
        the previous behaviour of extracting into the current working directory.

    Raises
    ------
    Exception
        If downloading or extracting fails. The original error is chained as
        the cause so the full traceback stays available.
    """
    try:
        gdown.download(url, zip_file_name)
        with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
            zip_ref.extractall(extract_dir)
    except Exception as e:
        raise Exception(
            f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in './datasets' folder"
        ) from e
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(zip_file_name):
            os.remove(zip_file_name)


def get_dataset(train=True, get_pseudo_label=False):
    """Load the HWF dataset, downloading it first if the data folder is missing.

    Parameters
    ----------
    train : bool, optional
        Read ``expr_train.json`` when True, ``expr_test.json`` otherwise.
    get_pseudo_label : bool, optional
        When True, also collect the symbol label of every image, taken from
        the first component of the image's relative path.

    Returns
    -------
    tuple
        ``(X, Z, Y)``: ``X`` is a list of lists of transformed images, ``Z`` is
        the matching list of lists of labels (``None`` when ``get_pseudo_label``
        is False), and ``Y`` is the list of ``"res"`` values of the examples.
    """
    data_dir = os.path.join(CURRENT_DIR, "data")
    url = "https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download"

    if not os.path.exists(data_dir):
        print("Dataset not exist, downloading it...")
        # Extract next to this module rather than into the working directory,
        # because that is where the data is read from below.
        # NOTE(review): assumes the archive's top-level folder is ``data`` -- confirm.
        download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"), extract_dir=CURRENT_DIR)
        print("Download and extraction complete.")

    file_name = "expr_train.json" if train else "expr_test.json"
    with open(os.path.join(data_dir, file_name), encoding="utf-8") as f:
        data = json.load(f)

    img_dir = os.path.join(data_dir, "Handwritten_Math_Symbols")
    X = []
    Z = [] if get_pseudo_label else None
    Y = []
    for sample in data:
        imgs = []
        imgs_pseudo_label = []
        for img_path in sample["img_paths"]:
            # Close the image file as soon as the grayscale copy has been made.
            with Image.open(os.path.join(img_dir, img_path)) as raw_img:
                imgs.append(img_transform(raw_img.convert("L")))
            if get_pseudo_label:
                imgs_pseudo_label.append(img_path.split("/")[0])
        X.append(imgs)
        if get_pseudo_label:
            Z.append(imgs_pseudo_label)
        Y.append(sample["res"])

    return X, Z, Y


def get_hwf(train=True, get_gt_pseudo_label=False):
    """Backward-compatible wrapper around :func:`get_dataset`.

    Keeps the name and keyword of the former ``get_hwf`` loader, which
    ``examples/hwf/main.py`` still imports from this module.
    """
    return get_dataset(train=train, get_pseudo_label=get_gt_pseudo_label)
def get_data(file, get_pseudo_label):
    """Read one HWF split from a JSON file.

    Returns ``(X, Z, Y)``: the transformed images of every example, the
    per-image labels (first component of each image path) or ``None`` when
    ``get_pseudo_label`` is false, and the ``"res"`` value of every example.
    """
    img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
    with open(file) as f:
        data = json.load(f)

    X, Y = [], []
    Z = [] if get_pseudo_label else None
    for sample in data:
        paths = sample["img_paths"]
        X.append([img_transform(Image.open(img_dir + p).convert("L")) for p in paths])
        if get_pseudo_label:
            Z.append([p.split("/")[0] for p in paths])
        Y.append(sample["res"])

    return X, Z, Y


def get_hwf(train=True, get_gt_pseudo_label=False):
    """Load the training split when ``train`` is true, the test split otherwise."""
    split_file = "data/expr_train.json" if train else "data/expr_test.json"
    return get_data(os.path.join(CURRENT_DIR, split_file), get_gt_pseudo_label)
\u001b[38;5;21;01mtorchvision\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtransforms\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m transforms\n\u001b[0;32m---> 10\u001b[0m CURRENT_DIR \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mabspath(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(\u001b[38;5;18;43m__file__\u001b[39;49m))\n\u001b[1;32m 12\u001b[0m img_transform \u001b[38;5;241m=\u001b[39m transforms\u001b[38;5;241m.\u001b[39mCompose([transforms\u001b[38;5;241m.\u001b[39mToTensor(), transforms\u001b[38;5;241m.\u001b[39mNormalize((\u001b[38;5;241m0.5\u001b[39m,), (\u001b[38;5;241m1\u001b[39m,))])\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mrequests\u001b[39;00m\n", + "\u001b[0;31mNameError\u001b[0m: name '__file__' is not defined" + ] + } + ], + "source": [ + "import json\n", + "import os\n", + "import requests\n", + "import os\n", + "import zipfile\n", + "\n", + "from PIL import Image\n", + "from torchvision.transforms import transforms\n", + "\n", + "CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))\n", + "\n", + "img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])\n", + "\n", + "import requests\n", + "import os\n", + "import zipfile\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "abl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index de39a01..fd2a62a 100644 --- 
a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -1,8 +1,26 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Handwritten Formula (HWF)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example shows a simple implementation of Handwritten Formula, which was first introduced in [Li et al., 2020](https://arxiv.org/abs/2006.06649). In this task, the inputs are images of decimal formulas, and the outputs are their computed results.\n", + "\n", + "In Abductive Learning, we hope to first use the learning part to map the input images to their symbols (we call them pseudo labels), and then use the reasoning part to evaluate the formula formed by these pseudo labels to get the final result.\n", + "\n", + "The HWF dataset contains images of decimal formulas and their computed results. " + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -18,14 +36,22 @@ "from abl.utils import ABLLogger, print_log\n", "\n", "from examples.models.nn import SymbolNet\n", - "from datasets.get_hwf import get_hwf" + "from examples.hwf.datasets.get_dataset import get_dataset" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12/18 12:48:19 - abl - INFO - Abductive Learning on the HWF example.\n" + ] + } + ], "source": [ "# Initialize logger and print basic information\n", "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", @@ -159,13 +185,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] 没有那个文件或目录: '/home/huwc/ABL-Package/examples/hwf/datasets/data/expr_train.json'", + "output_type": "error", + "traceback": [ +
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [4]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Get training and testing data\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m train_data \u001b[38;5;241m=\u001b[39m \u001b[43mget_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_pseudo_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m test_data \u001b[38;5;241m=\u001b[39m get_dataset(train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, get_pseudo_label\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m~/ABL-Package/examples/hwf/datasets/get_dataset.py:21\u001b[0m, in \u001b[0;36mget_dataset\u001b[0;34m(train, get_pseudo_label)\u001b[0m\n\u001b[1;32m 19\u001b[0m Y \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 20\u001b[0m img_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(CURRENT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/Handwritten_Math_Symbols/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 22\u001b[0m data \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mload(f)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(data)):\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] 没有那个文件或目录: 
'/home/huwc/ABL-Package/examples/hwf/datasets/data/expr_train.json'" + ] + } + ], "source": [ "# Get training and testing data\n", - "train_data = get_hwf(train=True, get_gt_pseudo_label=True)\n", - "test_data = get_hwf(train=False, get_gt_pseudo_label=True)" + "train_data = get_dataset(train=True, get_pseudo_label=True)\n", + "test_data = get_dataset(train=False, get_pseudo_label=True)" ] }, { @@ -220,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/hwf/main.py b/examples/hwf/main.py new file mode 100644 index 0000000..fe109da --- /dev/null +++ b/examples/hwf/main.py @@ -0,0 +1,114 @@ +# %% +import torch +import numpy as np +import torch.nn as nn +import os.path as osp + +from abl.reasoning import Reasoner, KBBase +from abl.learning import BasicNN, ABLModel +from abl.bridge import SimpleBridge +from abl.evaluation import SymbolMetric, ReasoningMetric +from abl.utils import ABLLogger, print_log + +from examples.models.nn import SymbolNet +from examples.hwf.datasets.get_dataset import get_hwf + +# %% +# Initialize logger and print basic information +print_log("Abductive Learning on the HWF example.", logger="current") + +# Retrieve the directory of the Log file and define the directory for saving the model weights. 
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")

# %% [markdown]
# ### Logic Part

# %%
# Initialize knowledge base and reasoner
class HWF_KB(KBBase):
    """Knowledge base for the Handwritten Formula (HWF) example.

    A formula is a sequence of pseudo labels that alternates between a
    single digit (``"1"``-``"9"``) and an operator (``"+"``, ``"-"``,
    ``"times"``, ``"div"``), starting and ending with a digit.
    """

    # Tuples keep the original list-membership semantics while being built
    # only once instead of on every loop iteration.
    _DIGITS = ("1", "2", "3", "4", "5", "6", "7", "8", "9")
    _OPERATORS = ("+", "-", "times", "div")

    # Translation from pseudo labels to Python expression tokens.
    _SYMBOL_TO_PYTHON = {
        **{digit: digit for digit in _DIGITS},
        "+": "+",
        "-": "-",
        "times": "*",
        "div": "/",
    }

    def _valid_candidate(self, formula):
        """Return True if ``formula`` is a well-formed digit/operator sequence.

        A well-formed formula has odd length, digits at even positions and
        operators at odd positions.
        """
        if len(formula) % 2 == 0:
            return False
        for position, symbol in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula`` and return its numeric result.

        Returns ``np.inf`` when the formula is not well-formed, so that an
        invalid candidate never matches a real target value.
        """
        if not self._valid_candidate(formula):
            return np.inf
        expression = "".join(self._SYMBOL_TO_PYTHON[symbol] for symbol in formula)
        # Safe use of eval: the expression can only contain the whitelisted
        # digits and arithmetic operators checked above. Builtins are disabled
        # as an extra guard; plain arithmetic does not need them.
        return eval(expression, {"__builtins__": {}}, {})


kb = HWF_KB(
    pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"],
    max_err=1e-10,
    use_cache=False,
)
reasoner = Reasoner(kb, dist_func="confidence")

# %% [markdown]
# ### Machine Learning Part

# %%
# Initialize necessary component for machine learning part
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

# %%
# Initialize BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
    model=cls,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    save_interval=1,
    save_dir=weights_dir,
    batch_size=128,
    num_epochs=3,
)

# %%
# Initialize ABL model
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)

# %% [markdown]
# ### Metric

# %%
# Add metric
metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]
+ +# %% [markdown] +# ### Dataset + +# %% +# Get training and testing data +train_data = get_hwf(train=True, get_pseudo_label=True) +test_data = get_hwf(train=False, get_pseudo_label=True) + +# %% [markdown] +# ### Bridge Machine Learning and Logic Reasoning + +# %% +bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list) + +# %% [markdown] +# ### Train and Test + +# %% +bridge.train(train_data, train_data, loops=3, segment_size=1000, save_interval=1, save_dir=weights_dir) +bridge.test(test_data) + + diff --git a/examples/hwf/requirements.txt b/examples/hwf/requirements.txt new file mode 100644 index 0000000..023006c --- /dev/null +++ b/examples/hwf/requirements.txt @@ -0,0 +1 @@ +gdown \ No newline at end of file diff --git a/examples/mnist_add/README.md b/examples/mnist_add/README.md index 3196994..cdf92eb 100644 --- a/examples/mnist_add/README.md +++ b/examples/mnist_add/README.md @@ -4,16 +4,15 @@ This example shows a simple implementation of [MNIST Addition](https://link) tas ## Run -``` -bash +```bash pip install -r requirements.txt python main.py ``` ## Usage -``` -usage: test.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] +```bash +usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] [--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE] [--loops LOOPS] [--segment_size SEGMENT_SIZE] [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION] @@ -34,7 +33,7 @@ optional arguments: batch size (default : 32) --loops LOOPS number of loop iterations (default : 5) --segment_size SEGMENT_SIZE - number of loop iterations (default : 1/3) + segment size (default : 1/3) --save_interval SAVE_INTERVAL save interval (default : 1) --max-revision MAX_REVISION diff --git a/examples/mnist_add/datasets/__init__.py b/examples/mnist_add/datasets/__init__.py index ecec715..a2e06bd 100644 --- a/examples/mnist_add/datasets/__init__.py +++ b/examples/mnist_add/datasets/__init__.py @@ -1,3 +1,3 @@ -from .get_mnist_add import get_mnist_add +from 
.get_dataset import get_dataset -__all__ = ["get_mnist_add"] \ No newline at end of file +__all__ = ["get_dataset"] \ No newline at end of file diff --git a/examples/mnist_add/datasets/get_mnist_add.py b/examples/mnist_add/datasets/get_dataset.py similarity index 53% rename from examples/mnist_add/datasets/get_mnist_add.py rename to examples/mnist_add/datasets/get_dataset.py index 553f187..0d1c4b9 100644 --- a/examples/mnist_add/datasets/get_mnist_add.py +++ b/examples/mnist_add/datasets/get_dataset.py @@ -5,25 +5,7 @@ from torchvision.transforms import transforms CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) -def get_data(file, img_dataset, get_pseudo_label): - X = [] - if get_pseudo_label: - Z = [] - Y = [] - with open(file) as f: - for line in f: - line = line.strip().split(" ") - X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) - if get_pseudo_label: - Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) - Y.append(int(line[2])) - - if get_pseudo_label: - return X, Z, Y - else: - return X, None, Y - -def get_mnist_add(train=True, get_pseudo_label=False): +def get_dataset(train=True, get_pseudo_label=False): transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] ) @@ -35,4 +17,14 @@ def get_mnist_add(train=True, get_pseudo_label=False): else: file = os.path.join(CURRENT_DIR, "test_data.txt") - return get_data(file, img_dataset, get_pseudo_label) + X = [] + Z = [] if get_pseudo_label else None + Y = [] + with open(file) as f: + for line in f: + x1, x2, y = map(int, line.strip().split(" ")) + X.append([img_dataset[x1][0], img_dataset[x2][0]]) + if get_pseudo_label: + Z.append([img_dataset[x1][1], img_dataset[x2][1]]) + Y.append(y) + return X, Z, Y diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index 504a146..72f10fe 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -5,13 +5,13 @@ import argparse import torch from torch 
import nn -from abl.bridge import SimpleBridge -from abl.evaluation import ReasoningMetric, SymbolMetric +from examples.mnist_add.datasets import get_dataset +from examples.models.nn import LeNet5 from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner +from abl.evaluation import ReasoningMetric, SymbolMetric from abl.utils import ABLLogger, print_log -from examples.mnist_add.datasets import get_mnist_add -from examples.models.nn import LeNet5 +from abl.bridge import SimpleBridge class AddKB(KBBase): def __init__(self, pseudo_label_list=list(range(10))): @@ -42,7 +42,7 @@ def main(): parser.add_argument('--loops', type=int, default=5, help='number of loop iterations (default : 5)') parser.add_argument('--segment_size', type=int or float, default=1/3, - help='number of loop iterations (default : 1/3)') + help='segment size (default : 1/3)') parser.add_argument('--save_interval', type=int, default=1, help='save interval (default : 1)') parser.add_argument('--max-revision', type=int or float, default=-1, @@ -56,8 +56,12 @@ def main(): help='use GroundKB (default: False)') args = parser.parse_args() + + ### Working with Data + train_data = get_dataset(train=True, get_pseudo_label=True) + test_data = get_dataset(train=False, get_pseudo_label=True) - ### Learning Part + ### Building the Learning Part # Build necessary components for BasicNN cls = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss() @@ -66,7 +70,6 @@ def main(): device = torch.device("cuda" if use_cuda else "cpu") # Build BasicNN - # The function of BasicNN is to wrap NN models into the form of an sklearn estimator base_model = BasicNN( cls, loss_fn, @@ -77,27 +80,24 @@ def main(): ) # Build ABLModel - # The main function of the ABL model is to serialize data and - # provide a unified interface for different machine learning models model = ABLModel(base_model) + ### Building the Reasoning Part + # Build knowledge base if args.prolog: kb = 
PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") elif args.ground: kb = AddGroundKB() else: kb = AddKB() - reasoner = Reasoner(kb, dist_func="confidence", max_revision=args.max_revision, require_more_revision=args.require_more_revision) - - ### Evaluation Metrics - # Get training and testing data - train_data = get_mnist_add(train=True, get_pseudo_label=True) - test_data = get_mnist_add(train=False, get_pseudo_label=True) + + # Create reasoner + reasoner = Reasoner(kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision) - # Set up metrics + ### Building Evaluation Metrics metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] - ### Bridge Machine Learning and Logic Reasoning + ### Bridge Learning and Reasoning bridge = SimpleBridge(model, reasoner, metric_list) # Build logger diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add.ipynb similarity index 64% rename from examples/mnist_add/mnist_add_example.ipynb rename to examples/mnist_add/mnist_add.ipynb index 3f57b14..ff9bfaa 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -6,9 +6,9 @@ "source": [ "# MNIST Addition\n", "\n", - "This example shows a simple implementation of MNIST Addition, which was first introduced in [Manhaeve et al., 2018](https://arxiv.org/abs/1805.10872). In this task, the inputs are pairs of MNIST handwritten images, and the outputs are their sums.\n", + "This notebook shows an implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872). In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base which contains information on how to perform addition operations.
Our objective is to input a pair of handwritten images and accurately determine their sum.\n", "\n", - "In Abductive Learning, we hope to first use learning part to map the input images to their digits (we call it pseudo labels), and then use reasoning part to calculate the summation of these pseudo labels to get the final result." + "Intuitively, we first use a machine learning model (learning part) to convert the input images to digits (we call them pseudo labels), and then use the knowledge base (reasoning part) to calculate the sum of these digits. Since we do not have ground-truth of the digits, the reasoning part will leverage domain knowledge and revise the initial digits yielded by the learning part into results derived from abductive reasoning. This process enables us to further refine and retrain the machine learning model." ] }, { @@ -17,28 +17,27 @@ "metadata": {}, "outputs": [], "source": [ + "# Import necessary libraries and modules\n", "import os.path as osp\n", - "\n", "import torch\n", "import torch.nn as nn\n", "import matplotlib.pyplot as plt\n", - "\n", - "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import ReasoningMetric, SymbolMetric\n", + "from examples.mnist_add.datasets import get_dataset\n", + "from examples.models.nn import LeNet5\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", + "from abl.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.utils import ABLLogger, print_log\n", - "from examples.mnist_add.datasets import get_mnist_add\n", - "from examples.models.nn import LeNet5" + "from abl.bridge import SimpleBridge" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Load Datasets\n", + "## Working with Data\n", "\n", - "First, we get training and testing data:" + "First, we get the training and testing datasets:" ] }, { @@ -47,15 +46,15 @@ "metadata": {}, "outputs": [], "source": [ - "train_data = get_mnist_add(train=True, 
get_pseudo_label=True)\n", - "test_data = get_mnist_add(train=False, get_pseudo_label=True)" + "train_data = get_dataset(train=True, get_pseudo_label=True)\n", + "test_data = get_dataset(train=False, get_pseudo_label=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The datasets are illustrated as follows:" + "Both datasets contain several data examples. In each data example, we have three components: X (a pair of images), gt_pseudo_label (a pair of corresponding ground truth digits, i.e., pseudo labels), and Y (their sum). The datasets are illustrated as follows. " ] }, { @@ -68,8 +67,7 @@ "output_type": "stream", "text": [ "There are 30000 data examples in the training set and 5000 data examples in the test set\n", - "Each of the data example has 3 components: X, gt_pseudo_label, and Y.\n", - "As an illustration, in the First data example of the training set, we have:\n", + "As an illustration, in the first data example of the training set, we have:\n", "X (2 images):\n" ] }, @@ -94,8 +92,7 @@ ], "source": [ "print(f\"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set\")\n", - "print(f\"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.\")\n", - "print(\"As an illustration, in the First data example of the training set, we have:\")\n", + "print(\"As an illustration, in the first data example of the training set, we have:\")\n", "print(f\"X ({len(train_data[0][0])} images):\")\n", "plt.subplot(1,2,1)\n", "plt.axis('off') \n", @@ -113,14 +110,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Learning Part" + "## Building the Learning Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "First, we build the basic learning model. We use a simple [LeNet neural network](https://en.wikipedia.org/wiki/LeNet) to complete this task." + "To build the learning part, we need to first build a base machine learning model. 
We use a simple [LeNet-5 neural network](https://en.wikipedia.org/wiki/LeNet) to complete this task, and encapsulate it within a `BasicNN` object to create the base model. `BasicNN` is a class that encapsulates a PyTorch model, transforming it into a base model with an sklearn-style interface. " ] }, { @@ -131,7 +128,7 @@ "source": [ "cls = LeNet5(num_classes=10)\n", "loss_fn = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))\n", + "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "base_model = BasicNN(\n", @@ -148,7 +145,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The base model can predict the outcome class index and the probabilities for an image, as shown below:" + "`BasicNN` offers methods like `predict` and `predict_prob`, which are used to predict the outcome class index and the probabilities for an image, respectively, as shown below:" ] }, { @@ -176,7 +173,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Then, we build an instance of `ABLModel`. The main function of `ABLModel` is to serialize data and provide a unified interface for different base machine learning models." + "However, the base model built above is trained to make predictions on instance-level data, i.e., a single image, and cannot directly utilize sample-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel`, which enables the learning part to train, test, and predict on sample-level data."
] }, { @@ -192,44 +189,63 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Logic Part" + "TODO: add an example showing how `predict` differs between `ABLModel` and the base model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# from abl.structures import ListData\n", + "# data_samples = ListData()\n", + "# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n", + "\n", + "# model.predict(data_samples)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Building the Reasoning Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In the logic part, we first build a knowledge base." + "In the reasoning part, we first build a knowledge base which contains information on how to perform addition operations. We build it by creating a subclass of `KBBase`. In the derived subclass, we have to first initialize the `pseudo_label_list` parameter specifying the list of possible pseudo labels, and then override the `logic_forward` function defining how to perform (deductive) reasoning." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "# Build knowledge base and reasoner\n", "class AddKB(KBBase):\n", - " def __init__(self, pseudo_label_list):\n", + " def __init__(self, pseudo_label_list=list(range(10))):\n", " super().__init__(pseudo_label_list)\n", "\n", " # Implement the deduction function\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = AddKB(pseudo_label_list=list(range(10)))" + "kb = AddKB()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The knowledge base can perform logical reasoning. Below is an example of performing (deductive) reasoning:" + "The knowledge base can perform logical reasoning.
Below is an example of performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -250,16 +266,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Then, we create a reasoner. It can help minimize inconsistencies between the knowledge base and pseudo labels predicted by the learning part." + "Note: In addition to building a knowledge base based on `KBBase`, we can also establish a knowledge base with a ground KB using `GroundKB`, or a knowledge base implemented based on Prolog files using `PrologKB`. The corresponding code for these implementations can be found in the `main.py` file. Those interested are encouraged to examine it for further insights." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we create a reasoner by instantiating the class ``Reasoner``. Due to the indeterminism of abductive reasoning, there could be multiple candidates compatible with the knowledge base. When this happens, the reasoner can minimize inconsistencies between the knowledge base and the pseudo labels predicted by the learning part, and then return only the one candidate with the highest consistency." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "reasoner = Reasoner(kb, dist_func=\"confidence\")" + "reasoner = Reasoner(kb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: When creating the reasoner, the definition of \"consistency\" can be customized via the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data sample and candidates based on the confidence derived from the predicted probability.
In `main.py`, we provide options for utilizing other forms of consistency measurement.\n", + "\n", + "Also, during process of inconsistency minimization, one can leverage [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features." ] }, { @@ -267,19 +299,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Evaluation Metrics" + "## Building Evaluation Metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing." + "Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing. Specifically, we use `SymbolMetric` and `ReasoningMetric`, which are used to evaluate the accuracy of the machine learning model’s predictions and the accuracy of the final reasoning results, respectively." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -293,12 +325,12 @@ "source": [ "## Bridge Learning and Reasoning\n", "\n", - "Now, the last step is to bridge the learning and reasoning part." + "Now, the last step is to bridge the learning and reasoning part. We proceed this step by creating an instance of `SimpleBridge`." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -309,7 +341,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Perform training and testing." + "Perform training and testing by invoking the `train` and `test` methods of `SimpleBridge`." 
] }, { @@ -344,7 +376,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/mnist_add/requirements.txt b/examples/mnist_add/requirements.txt index 887e75b..24fb7af 100644 --- a/examples/mnist_add/requirements.txt +++ b/examples/mnist_add/requirements.txt @@ -1,4 +1,2 @@ -torch -torchvision -torchaudio +abl matplotlib \ No newline at end of file