| @@ -0,0 +1,47 @@ | |||
| # MNIST Addition Example | |||
| This example shows a simple implementation of [MNIST Addition](https://link) task, where the inputs are pairs of MNIST handwritten images, and the outputs are their sums. | |||
| ## Run | |||
| ```bash | |||
| pip install -r requirements.txt | |||
| python main.py | |||
| ``` | |||
| ## Usage | |||
| ``` | |||
| usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] | |||
| [--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE] | |||
| [--loops LOOPS] [--segment_size SEGMENT_SIZE] | |||
| [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION] | |||
| [--require-more-revision REQUIRE_MORE_REVISION] | |||
| [--prolog | --ground] | |||
| MNIST Addition example | |||
| optional arguments: | |||
| -h, --help show this help message and exit | |||
| --no-cuda disables CUDA training | |||
| --epochs EPOCHS number of epochs in each learning loop iteration | |||
| (default : 1) | |||
| --lr LR base learning rate (default : 0.001) | |||
| --weight-decay WEIGHT_DECAY | |||
| weight decay value (default : 0.03) | |||
| --batch-size BATCH_SIZE | |||
| batch size (default : 32) | |||
| --loops LOOPS number of loop iterations (default : 5) | |||
| --segment_size SEGMENT_SIZE | |||
| segment size of data used in each learning loop iteration (default : 1/3) | |||
| --save_interval SAVE_INTERVAL | |||
| save interval (default : 1) | |||
| --max-revision MAX_REVISION | |||
| maximum revision in reasoner (default : -1) | |||
| --require-more-revision REQUIRE_MORE_REVISION | |||
| require more revision in reasoner (default : 0) | |||
| --prolog use PrologKB (default: False) | |||
| --ground use GroundKB (default: False) | |||
| ``` | |||
| @@ -0,0 +1,3 @@ | |||
| from .get_mnist_add import get_mnist_add | |||
| __all__ = ["get_mnist_add"] | |||
| @@ -5,7 +5,6 @@ from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| def get_data(file, img_dataset, get_pseudo_label): | |||
| X = [] | |||
| if get_pseudo_label: | |||
| @@ -24,7 +23,6 @@ def get_data(file, img_dataset, get_pseudo_label): | |||
| else: | |||
| return X, None, Y | |||
| def get_mnist_add(train=True, get_pseudo_label=False): | |||
| transform = transforms.Compose( | |||
| [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | |||
| @@ -32,17 +30,9 @@ def get_mnist_add(train=True, get_pseudo_label=False): | |||
| img_dataset = torchvision.datasets.MNIST( | |||
| root=CURRENT_DIR, train=train, download=True, transform=transform | |||
| ) | |||
| if train: | |||
| file = os.path.join(CURRENT_DIR, "train_data.txt") | |||
| else: | |||
| file = os.path.join(CURRENT_DIR, "test_data.txt") | |||
| return get_data(file, img_dataset, get_pseudo_label) | |||
if __name__ == "__main__":
    # Smoke-test the loaders. get_data returns a 3-tuple (X, pseudo_label, Y)
    # where pseudo_label is None when get_pseudo_label=False, so all three
    # components must be unpacked (the previous 2-way unpack raised ValueError).
    train_X, _, train_Y = get_mnist_add(train=True)
    test_X, _, test_Y = get_mnist_add(train=False)
    print(len(train_X), len(test_X))
    print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])
| @@ -0,0 +1,118 @@ | |||
| import os | |||
| import os.path as osp | |||
| import argparse | |||
| import torch | |||
| from torch import nn | |||
| from abl.bridge import SimpleBridge | |||
| from abl.evaluation import ReasoningMetric, SymbolMetric | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner | |||
| from abl.utils import ABLLogger, print_log | |||
| from examples.mnist_add.datasets import get_mnist_add | |||
| from examples.models.nn import LeNet5 | |||
class AddKB(KBBase):
    """Knowledge base for MNIST Addition.

    A pseudo-label is a single digit (0-9); the logical result of a list of
    pseudo-labels is their arithmetic sum.
    """

    def __init__(self, pseudo_label_list=None):
        # Avoid a mutable default argument: `pseudo_label_list=list(range(10))`
        # is evaluated once at definition time and shared across instances.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        """Return the sum of a sequence of digit pseudo-labels."""
        return sum(nums)
class AddGroundKB(GroundKB):
    """Ground knowledge base for MNIST Addition.

    Same logic as AddKB, but backed by a pre-grounded knowledge base over
    pseudo-label sequences of length 2 (two-image addition).
    """

    def __init__(self, pseudo_label_list=None, GKB_len_list=None):
        # Avoid mutable default arguments (list(range(10)) and [2]), which are
        # evaluated once at definition time and shared across all calls.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2]
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        """Return the sum of a sequence of digit pseudo-labels."""
        return sum(nums)
| def main(): | |||
| parser = argparse.ArgumentParser(description='MNIST Addition example') | |||
| parser.add_argument('--no-cuda', action='store_true', default=False, | |||
| help='disables CUDA training') | |||
| parser.add_argument('--epochs', type=int, default=1, | |||
| help='number of epochs in each learning loop iteration (default : 1)') | |||
| parser.add_argument('--lr', type=float, default=1e-3, | |||
| help='base learning rate (default : 0.001)') | |||
| parser.add_argument('--weight-decay', type=int, default=3e-2, | |||
| help='weight decay value (default : 0.03)') | |||
| parser.add_argument('--batch-size', type=int, default=32, | |||
| help='batch size (default : 32)') | |||
| parser.add_argument('--loops', type=int, default=5, | |||
| help='number of loop iterations (default : 5)') | |||
| parser.add_argument('--segment_size', type=int or float, default=1/3, | |||
| help='number of loop iterations (default : 1/3)') | |||
| parser.add_argument('--save_interval', type=int, default=1, | |||
| help='save interval (default : 1)') | |||
| parser.add_argument('--max-revision', type=int or float, default=-1, | |||
| help='maximum revision in reasoner (default : -1)') | |||
| parser.add_argument('--require-more-revision', type=int, default=5, | |||
| help='require more revision in reasoner (default : 0)') | |||
| kb_type = parser.add_mutually_exclusive_group() | |||
| kb_type.add_argument("--prolog", action="store_true", default=False, | |||
| help='use PrologKB (default: False)') | |||
| kb_type.add_argument("--ground", action="store_true", default=False, | |||
| help='use GroundKB (default: False)') | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| ### Learning Part | |||
| # Build necessary components for BasicNN | |||
| cls = LeNet5(num_classes=10) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr) | |||
| use_cuda = not args.no_cuda and torch.cuda.is_available() | |||
| device = torch.device("cuda" if use_cuda else "cpu") | |||
| # Build BasicNN | |||
| # The function of BasicNN is to wrap NN models into the form of an sklearn estimator | |||
| base_model = BasicNN( | |||
| cls, | |||
| loss_fn, | |||
| optimizer, | |||
| device=device, | |||
| batch_size=args.batch_size, | |||
| num_epochs=args.epochs, | |||
| ) | |||
| # Build ABLModel | |||
| # The main function of the ABL model is to serialize data and | |||
| # provide a unified interface for different machine learning models | |||
| model = ABLModel(base_model) | |||
| if args.prolog: | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| elif args.ground: | |||
| kb = AddGroundKB() | |||
| else: | |||
| kb = AddKB() | |||
| reasoner = Reasoner(kb, dist_func="confidence", max_revision=args.max_revision, require_more_revision=args.require_more_revision) | |||
| ### Datasets and Evaluation Metrics | |||
| # Get training and testing data | |||
| train_data = get_mnist_add(train=True, get_pseudo_label=True) | |||
| test_data = get_mnist_add(train=False, get_pseudo_label=True) | |||
| # Set up metrics | |||
| metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
| ### Bridge Machine Learning and Logic Reasoning | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Train and Test | |||
| bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) | |||
| bridge.test(test_data) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -10,13 +10,15 @@ | |||
| "\n", | |||
| "import torch\n", | |||
| "import torch.nn as nn\n", | |||
| "import matplotlib.pyplot as plt\n", | |||
| "import matplotlib.image as mpimg\n", | |||
| "\n", | |||
| "from abl.bridge import SimpleBridge\n", | |||
| "from abl.evaluation import ReasoningMetric, SymbolMetric\n", | |||
| "from abl.learning import ABLModel, BasicNN\n", | |||
| "from abl.reasoning import KBBase, Reasoner\n", | |||
| "from abl.utils import ABLLogger, print_log\n", | |||
| "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", | |||
| "from examples.mnist_add.datasets import get_mnist_add\n", | |||
| "from examples.models.nn import LeNet5" | |||
| ] | |||
| }, | |||
| @@ -27,19 +29,58 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Build logger\n", | |||
| "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", | |||
| "print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n", | |||
| "\n", | |||
| "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", | |||
| "log_dir = ABLLogger.get_current_instance().log_dir\n", | |||
| "weights_dir = osp.join(log_dir, \"weights\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Load Datasets" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Get training and testing data\n", | |||
| "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", | |||
| "test_data = get_mnist_add(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "print(f\"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set\")\n", | |||
| "print(f\"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.\")\n", | |||
| "print(\"For instance, in the First data example in the training set, we have:\")\n", | |||
| "print(f\"X ({len(train_data[0][0])} images):\")\n", | |||
| "plt.subplot(1,2,1)\n", | |||
| "plt.axis('off') \n", | |||
| "plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))\n", | |||
| "plt.subplot(1,2,2)\n", | |||
| "plt.axis('off') \n", | |||
| "plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))\n", | |||
| "plt.show()\n", | |||
| "print(f\"gt_pseudo_label ({len(train_data[1][0])} ground truth pseudo label): {train_data[1][0][0]}, {train_data[1][0][1]}\")\n", | |||
| "print(f\"Y (their sum result): {train_data[2][0]}\")" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### Machine Learning Part" | |||
| "### Learning Part" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -120,17 +161,6 @@ | |||
| "### Datasets and Evaluation Metrics" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Get training and testing data\n", | |||
| "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", | |||
| "test_data = get_mnist_add(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||