diff --git a/examples/mnist_add/README.md b/examples/mnist_add/README.md
new file mode 100644
index 0000000..3196994
--- /dev/null
+++ b/examples/mnist_add/README.md
@@ -0,0 +1,46 @@
+# MNIST Addition Example
+
+This example shows a simple implementation of [MNIST Addition](https://arxiv.org/abs/1805.10872) task, where the inputs are pairs of MNIST handwritten images, and the outputs are their sums.
+
+## Run
+
+```bash
+pip install -r requirements.txt
+python main.py
+```
+
+## Usage
+
+```
+usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
+               [--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]
+               [--loops LOOPS] [--segment_size SEGMENT_SIZE]
+               [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
+               [--require-more-revision REQUIRE_MORE_REVISION]
+               [--prolog | --ground]
+
+MNIST Addition example
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --no-cuda             disables CUDA training
+  --epochs EPOCHS       number of epochs in each learning loop iteration
+                        (default : 1)
+  --lr LR               base learning rate (default : 0.001)
+  --weight-decay WEIGHT_DECAY
+                        weight decay value (default : 0.03)
+  --batch-size BATCH_SIZE
+                        batch size (default : 32)
+  --loops LOOPS         number of loop iterations (default : 5)
+  --segment_size SEGMENT_SIZE
+                        segment size (default : 1/3)
+  --save_interval SAVE_INTERVAL
+                        save interval (default : 1)
+  --max-revision MAX_REVISION
+                        maximum revision in reasoner (default : -1)
+  --require-more-revision REQUIRE_MORE_REVISION
+                        require more revision in reasoner (default : 0)
+  --prolog              use PrologKB (default: False)
+  --ground              use GroundKB (default: False)
+
+```
diff --git a/examples/mnist_add/datasets/add.pl b/examples/mnist_add/add.pl
similarity index 100%
rename from examples/mnist_add/datasets/add.pl
rename to examples/mnist_add/add.pl
diff --git a/examples/mnist_add/datasets/__init__.py b/examples/mnist_add/datasets/__init__.py
new file mode 100644
index 0000000..ecec715
--- /dev/null
+++ b/examples/mnist_add/datasets/__init__.py
@@ -0,0 +1,3 @@ +from .get_mnist_add import get_mnist_add + +__all__ = ["get_mnist_add"] \ No newline at end of file diff --git a/examples/mnist_add/datasets/get_mnist_add.py b/examples/mnist_add/datasets/get_mnist_add.py index 4bbb834..553f187 100644 --- a/examples/mnist_add/datasets/get_mnist_add.py +++ b/examples/mnist_add/datasets/get_mnist_add.py @@ -5,7 +5,6 @@ from torchvision.transforms import transforms CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) - def get_data(file, img_dataset, get_pseudo_label): X = [] if get_pseudo_label: @@ -24,7 +23,6 @@ def get_data(file, img_dataset, get_pseudo_label): else: return X, None, Y - def get_mnist_add(train=True, get_pseudo_label=False): transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] @@ -32,17 +30,9 @@ def get_mnist_add(train=True, get_pseudo_label=False): img_dataset = torchvision.datasets.MNIST( root=CURRENT_DIR, train=train, download=True, transform=transform ) - if train: file = os.path.join(CURRENT_DIR, "train_data.txt") else: file = os.path.join(CURRENT_DIR, "test_data.txt") return get_data(file, img_dataset, get_pseudo_label) - - -if __name__ == "__main__": - train_X, train_Y = get_mnist_add(train=True) - test_X, test_Y = get_mnist_add(train=False) - print(len(train_X), len(test_X)) - print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py new file mode 100644 index 0000000..c2163c4 --- /dev/null +++ b/examples/mnist_add/main.py @@ -0,0 +1,118 @@ +import os +import os.path as osp +import argparse + +import torch +from torch import nn + +from abl.bridge import SimpleBridge +from abl.evaluation import ReasoningMetric, SymbolMetric +from abl.learning import ABLModel, BasicNN +from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner +from abl.utils import ABLLogger, print_log +from examples.mnist_add.datasets import get_mnist_add +from examples.models.nn import LeNet5 + 
+class AddKB(KBBase):
+    def __init__(self, pseudo_label_list=list(range(10))):
+        super().__init__(pseudo_label_list)
+
+    def logic_forward(self, nums):
+        return sum(nums)
+
+class AddGroundKB(GroundKB):
+    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
+        super().__init__(pseudo_label_list, GKB_len_list)
+
+    def logic_forward(self, nums):
+        return sum(nums)
+
+def main():
+    parser = argparse.ArgumentParser(description='MNIST Addition example')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--epochs', type=int, default=1,
+                        help='number of epochs in each learning loop iteration (default : 1)')
+    parser.add_argument('--lr', type=float, default=1e-3,
+                        help='base learning rate (default : 0.001)')
+    parser.add_argument('--weight-decay', type=float, default=3e-2,
+                        help='weight decay value (default : 0.03)')
+    parser.add_argument('--batch-size', type=int, default=32,
+                        help='batch size (default : 32)')
+    parser.add_argument('--loops', type=int, default=5,
+                        help='number of loop iterations (default : 5)')
+    parser.add_argument('--segment_size', type=lambda s: int(s) if s.lstrip('+-').isdigit() else float(s), default=1/3,
+                        help='segment size (default : 1/3)')
+    parser.add_argument('--save_interval', type=int, default=1,
+                        help='save interval (default : 1)')
+    parser.add_argument('--max-revision', type=lambda s: int(s) if s.lstrip('+-').isdigit() else float(s), default=-1,
+                        help='maximum revision in reasoner (default : -1)')
+    parser.add_argument('--require-more-revision', type=int, default=0,
+                        help='require more revision in reasoner (default : 0)')
+    kb_type = parser.add_mutually_exclusive_group()
+    kb_type.add_argument("--prolog", action="store_true", default=False,
+                        help='use PrologKB (default: False)')
+    kb_type.add_argument("--ground", action="store_true", default=False,
+                        help='use GroundKB (default: False)')
+
+    args = parser.parse_args()
+
+    # Build logger
+    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
+ + # Retrieve the directory of the Log file and define the directory for saving the model weights. + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + + ### Learning Part + # Build necessary components for BasicNN + cls = LeNet5(num_classes=10) + loss_fn = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr) + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + + # Build BasicNN + # The function of BasicNN is to wrap NN models into the form of an sklearn estimator + base_model = BasicNN( + cls, + loss_fn, + optimizer, + device=device, + batch_size=args.batch_size, + num_epochs=args.epochs, + ) + + # Build ABLModel + # The main function of the ABL model is to serialize data and + # provide a unified interface for different machine learning models + model = ABLModel(base_model) + + if args.prolog: + kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") + elif args.ground: + kb = AddGroundKB() + else: + kb = AddKB() + reasoner = Reasoner(kb, dist_func="confidence", max_revision=args.max_revision, require_more_revision=args.require_more_revision) + + ### Datasets and Evaluation Metrics + # Get training and testing data + train_data = get_mnist_add(train=True, get_pseudo_label=True) + test_data = get_mnist_add(train=False, get_pseudo_label=True) + + # Set up metrics + metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] + + ### Bridge Machine Learning and Logic Reasoning + bridge = SimpleBridge(model, reasoner, metric_list) + + # Train and Test + bridge.train(train_data, loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir) + bridge.test(test_data) + + + + +if __name__ == "__main__": + main() diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 9193d1f..4631c08 100644 --- 
a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -10,13 +10,15 @@ "\n", "import torch\n", "import torch.nn as nn\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", "\n", "from abl.bridge import SimpleBridge\n", "from abl.evaluation import ReasoningMetric, SymbolMetric\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", "from abl.utils import ABLLogger, print_log\n", - "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", + "from examples.mnist_add.datasets import get_mnist_add\n", "from examples.models.nn import LeNet5" ] }, @@ -27,19 +29,58 @@ "outputs": [], "source": [ "# Build logger\n", - "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", + "print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n", "\n", "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", "log_dir = ABLLogger.get_current_instance().log_dir\n", "weights_dir = osp.join(log_dir, \"weights\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get training and testing data\n", + "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", + "test_data = get_mnist_add(train=False, get_pseudo_label=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set\")\n", + "print(f\"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.\")\n", + "print(\"For instance, in the First data example in the training set, we have:\")\n", + "print(f\"X ({len(train_data[0][0])} 
images):\")\n", + "plt.subplot(1,2,1)\n", + "plt.axis('off') \n", + "plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))\n", + "plt.subplot(1,2,2)\n", + "plt.axis('off') \n", + "plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))\n", + "plt.show()\n", + "print(f\"gt_pseudo_label ({len(train_data[1][0])} ground truth pseudo label): {train_data[1][0][0]}, {train_data[1][0][1]}\")\n", + "print(f\"Y (their sum result): {train_data[2][0]}\")" + ] + }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "### Machine Learning Part" + "### Learning Part" ] }, { @@ -120,17 +161,6 @@ "### Datasets and Evaluation Metrics" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Get training and testing data\n", - "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", - "test_data = get_mnist_add(train=False, get_pseudo_label=True)" - ] - }, { "cell_type": "code", "execution_count": null,