MNIST Addition
==============

This example shows a simple implementation of MNIST Addition, which was first introduced in `Manhaeve et al., 2018 <https://arxiv.org/abs/1805.10872>`__. In this task, the inputs are pairs of MNIST handwritten images, and the outputs are their sums. In Abductive Learning, the learning part first maps the input images to their digits (which we call pseudo labels), and the reasoning part then sums these pseudo labels to obtain the final result.

.. code:: ipython3

    import os.path as osp

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

    from abl.bridge import SimpleBridge
    from abl.evaluation import ReasoningMetric, SymbolMetric
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.utils import ABLLogger, print_log

    from examples.mnist_add.datasets import get_mnist_add
    from examples.models.nn import LeNet5

Load Datasets
-------------

First, we get the training and testing data:

.. code:: ipython3

    train_data = get_mnist_add(train=True, get_pseudo_label=True)
    test_data = get_mnist_add(train=False, get_pseudo_label=True)

The datasets are illustrated as follows:

.. code:: ipython3

    print(f"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set")
    print(f"Each data example has {len(train_data)} components: X, gt_pseudo_label, and Y.")
    print("As an illustration, in the first data example of the training set, we have:")
    print(f"X ({len(train_data[0][0])} images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label ({len(train_data[1][0])} ground truth pseudo labels): {train_data[1][0][0]}, {train_data[1][0][1]}")
    print(f"Y (their sum): {train_data[2][0]}")

Out:

.. code:: none
    :class: code-out

    There are 30000 data examples in the training set and 5000 data examples in the test set
    Each data example has 3 components: X, gt_pseudo_label, and Y.
    As an illustration, in the first data example of the training set, we have:
    X (2 images):

.. image:: ../img/mnist_add_datasets.png
   :width: 400px

.. code:: none
    :class: code-out

    gt_pseudo_label (2 ground truth pseudo labels): 7, 5
    Y (their sum): 12

Learning Part
-------------

First, we build the basic learning model. We use a simple `LeNet-5 <https://en.wikipedia.org/wiki/LeNet>`__ neural network to complete this task.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device,
        batch_size=32,
        num_epochs=1,
    )

The base model can predict the class index and the class probabilities of an image, as shown below:

.. code:: ipython3

    pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}")

Out:

.. code:: none
    :class: code-out

    Shape of pred_idx for a batch of 32 samples: (32,)
    Shape of pred_prob for a batch of 32 samples: (32, 10)
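As a quick sanity check, each row of ``pred_prob`` should be a probability distribution over the 10 digit classes. The following minimal sketch verifies this, assuming ``predict_proba`` returns a NumPy array of softmax outputs (consistent with the ``(32, 10)`` shape printed above):

.. code:: ipython3

    import numpy as np

    # Each row of pred_prob is the predicted distribution over the 10 digit
    # classes for one image, so every row should sum to (approximately) 1.
    print(np.allclose(pred_prob.sum(axis=1), 1.0))  # expected to print True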
Then, we build an instance of ``ABLModel``. The main function of ``ABLModel`` is to serialize data and provide a unified interface for different base machine learning models.

.. code:: ipython3

    model = ABLModel(base_model)

Logic Part
----------

In the logic part, we first build a knowledge base.

.. code:: ipython3

    # Build the knowledge base
    class AddKB(KBBase):
        def __init__(self, pseudo_label_list):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB(pseudo_label_list=list(range(10)))

The knowledge base can perform logical reasoning. Below is an example of (deductive) reasoning:

.. code:: ipython3

    pseudo_label_sample = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_label_sample)
    print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.")

Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo label sample [1, 2] is 3.

Then, we create a reasoner. It helps minimize the inconsistency between the knowledge base and the pseudo labels predicted by the learning part.

.. code:: ipython3

    reasoner = Reasoner(kb, dist_func="confidence")
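To make the reasoner's job concrete, below is a hand-rolled sketch of the idea behind abductive revision. It is not the library's own abduction API, and it ranks candidates by a simple count of revised positions, whereas the actual reasoner with ``dist_func="confidence"`` weighs candidates by the model's predicted probabilities. Suppose the learning part predicts the pseudo labels ``[1, 2]`` but the ground-truth sum is 8: we enumerate the pseudo label pairs consistent with the knowledge base and prefer those that revise the fewest predictions.

.. code:: ipython3

    from itertools import product

    pred_pseudo_label = [1, 2]  # inconsistent prediction
    y = 8                       # ground-truth reasoning result

    # Enumerate all pseudo label pairs whose deduction result equals y ...
    candidates = [list(c) for c in product(range(10), repeat=2)
                  if kb.logic_forward(list(c)) == y]
    # ... and order them by how many positions differ from the prediction.
    candidates.sort(key=lambda c: sum(p != q for p, q in zip(pred_pseudo_label, c)))
    print(candidates[:3])  # [[1, 7], [6, 2], [0, 8]] -- minimal-revision candidates first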
Evaluation Metrics
------------------

Next, we set up evaluation metrics. These metrics will be used to evaluate the model performance during training and testing.

.. code:: ipython3

    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # Retrieve the directory of the log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the MNIST Addition example.
    abl - INFO - loop(train) [1/5] segment(train) [1/3]
    abl - INFO - model loss: 1.81231
    abl - INFO - loop(train) [1/5] segment(train) [2/3]
    abl - INFO - model loss: 1.37639
    abl - INFO - loop(train) [1/5] segment(train) [3/3]
    abl - INFO - model loss: 1.14446
    abl - INFO - Evaluation start: loop(val) [1]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.207 mnist_add/reasoning_accuracy: 0.245
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_1.pth
    abl - INFO - loop(train) [2/5] segment(train) [1/3]
    abl - INFO - model loss: 0.97430
    abl - INFO - loop(train) [2/5] segment(train) [2/3]
    abl - INFO - model loss: 0.91448
    abl - INFO - loop(train) [2/5] segment(train) [3/3]
    abl - INFO - model loss: 0.83089
    abl - INFO - Evaluation start: loop(val) [2]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.191 mnist_add/reasoning_accuracy: 0.353
    abl - INFO - Saving model: loop(save) [2]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_2.pth
    abl - INFO - loop(train) [3/5] segment(train) [1/3]
    abl - INFO - model loss: 0.79906
    abl - INFO - loop(train) [3/5] segment(train) [2/3]
    abl - INFO - model loss: 0.77949
    abl - INFO - loop(train) [3/5] segment(train) [3/3]
    abl - INFO - model loss: 0.75007
    abl - INFO - Evaluation start: loop(val) [3]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.148 mnist_add/reasoning_accuracy: 0.385
    abl - INFO - Saving model: loop(save) [3]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_3.pth
    abl - INFO - loop(train) [4/5] segment(train) [1/3]
    abl - INFO - model loss: 0.72659
    abl - INFO - loop(train) [4/5] segment(train) [2/3]
    abl - INFO - model loss: 0.70985
    abl - INFO - loop(train) [4/5] segment(train) [3/3]
    abl - INFO - model loss: 0.66337
    abl - INFO - Evaluation start: loop(val) [4]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.016 mnist_add/reasoning_accuracy: 0.494
    abl - INFO - Saving model: loop(save) [4]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_4.pth
    abl - INFO - loop(train) [5/5] segment(train) [1/3]
    abl - INFO - model loss: 0.61140
    abl - INFO - loop(train) [5/5] segment(train) [2/3]
    abl - INFO - model loss: 0.57534
    abl - INFO - loop(train) [5/5] segment(train) [3/3]
    abl - INFO - model loss: 0.57018
    abl - INFO - Evaluation start: loop(val) [5]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.507
    abl - INFO - Saving model: loop(save) [5]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_5.pth
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.482

More concrete examples are available in the ``examples/mnist_add`` folder.
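For a quick look at the training dynamics, the per-loop metrics from the log above can be re-plotted directly. This is a minimal sketch that only transcribes the logged values and reuses the ``matplotlib`` import from the top of this example:

.. code:: ipython3

    # Per-loop evaluation metrics transcribed from the training log above.
    loops = [1, 2, 3, 4, 5]
    character_accuracy = [0.207, 0.191, 0.148, 0.016, 0.002]
    reasoning_accuracy = [0.245, 0.353, 0.385, 0.494, 0.507]

    plt.plot(loops, character_accuracy, marker="o", label="mnist_add/character_accuracy")
    plt.plot(loops, reasoning_accuracy, marker="o", label="mnist_add/reasoning_accuracy")
    plt.xlabel("training loop")
    plt.ylabel("accuracy")
    plt.legend()
    plt.show()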