Browse Source

[ENH] change example structure (not complete)

pull/1/head
troyyyyy 2 years ago
parent
commit
586baaa799
6 changed files with 212 additions and 24 deletions
  1. +47
    -0
      examples/mnist_add/README.md
  2. +0
    -0
      examples/mnist_add/add.pl
  3. +3
    -0
      examples/mnist_add/datasets/__init__.py
  4. +0
    -10
      examples/mnist_add/datasets/get_mnist_add.py
  5. +118
    -0
      examples/mnist_add/main.py
  6. +44
    -14
      examples/mnist_add/mnist_add_example.ipynb

+ 47
- 0
examples/mnist_add/README.md View File

@@ -0,0 +1,47 @@
# MNIST Addition Example

This example shows a simple implementation of the [MNIST Addition](https://arxiv.org/abs/1805.10872) task, where the inputs are pairs of MNIST handwritten images, and the outputs are their sums.

## Run

```bash
pip install -r requirements.txt
python main.py
```

## Usage

```
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR]
[--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]
[--loops LOOPS] [--segment_size SEGMENT_SIZE]
[--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
[--require-more-revision REQUIRE_MORE_REVISION]
[--prolog | --ground]

MNIST Addition example

optional arguments:
-h, --help show this help message and exit
--no-cuda disables CUDA training
--epochs EPOCHS number of epochs in each learning loop iteration
(default : 1)
--lr LR base learning rate (default : 0.001)
--weight-decay WEIGHT_DECAY
weight decay value (default : 0.03)
--batch-size BATCH_SIZE
batch size (default : 32)
--loops LOOPS number of loop iterations (default : 5)
--segment_size SEGMENT_SIZE
segment size of data used in each loop iteration (default : 1/3)
--save_interval SAVE_INTERVAL
save interval (default : 1)
--max-revision MAX_REVISION
maximum revision in reasoner (default : -1)
--require-more-revision REQUIRE_MORE_REVISION
require more revision in reasoner (default : 0)
--prolog use PrologKB (default: False)
--ground use GroundKB (default: False)

```

examples/mnist_add/datasets/add.pl → examples/mnist_add/add.pl View File


+ 3
- 0
examples/mnist_add/datasets/__init__.py View File

@@ -0,0 +1,3 @@
"""Dataset helpers for the MNIST Addition example.

Re-exports :func:`get_mnist_add`, which builds (image-pair, pseudo-label,
sum) data from the MNIST dataset.
"""

from .get_mnist_add import get_mnist_add

__all__ = ["get_mnist_add"]

+ 0
- 10
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -5,7 +5,6 @@ from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def get_data(file, img_dataset, get_pseudo_label):
X = []
if get_pseudo_label:
@@ -24,7 +23,6 @@ def get_data(file, img_dataset, get_pseudo_label):
else:
return X, None, Y


def get_mnist_add(train=True, get_pseudo_label=False):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
@@ -32,17 +30,9 @@ def get_mnist_add(train=True, get_pseudo_label=False):
img_dataset = torchvision.datasets.MNIST(
root=CURRENT_DIR, train=train, download=True, transform=transform
)

if train:
file = os.path.join(CURRENT_DIR, "train_data.txt")
else:
file = os.path.join(CURRENT_DIR, "test_data.txt")

return get_data(file, img_dataset, get_pseudo_label)


if __name__ == "__main__":
train_X, train_Y = get_mnist_add(train=True)
test_X, test_Y = get_mnist_add(train=False)
print(len(train_X), len(test_X))
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])

+ 118
- 0
examples/mnist_add/main.py View File

@@ -0,0 +1,118 @@
import os
import os.path as osp
import argparse

import torch
from torch import nn

from abl.bridge import SimpleBridge
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
from abl.utils import ABLLogger, print_log
from examples.mnist_add.datasets import get_mnist_add
from examples.models.nn import LeNet5

class AddKB(KBBase):
    """Knowledge base for MNIST Addition.

    The logic result of a sequence of digit pseudo-labels is their sum.
    """

    def __init__(self, pseudo_label_list=None):
        # Use a None sentinel instead of a mutable default argument
        # (`list(range(10))` as a default would be shared across all calls).
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        """Return the sum of the digit pseudo-labels in ``nums``."""
        return sum(nums)

class AddGroundKB(GroundKB):
    """Ground knowledge base for MNIST Addition.

    Same logic as :class:`AddKB` (the result is the sum of the digit
    pseudo-labels), but backed by a pre-grounded KB of length-2 inputs.
    """

    def __init__(self, pseudo_label_list=None, GKB_len_list=None):
        # None sentinels instead of mutable default arguments, which would
        # be shared across every instantiation.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2]
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        """Return the sum of the digit pseudo-labels in ``nums``."""
        return sum(nums)

def _int_or_float(value):
    """argparse converter accepting either an int or a float literal.

    The original code used ``type=int or float``, which evaluates to plain
    ``int`` (the ``or`` short-circuits on the truthy ``int``), so fractional
    command-line values such as ``0.33`` raised an error.
    """
    try:
        return int(value)
    except ValueError:
        return float(value)


def main():
    """Run the MNIST Addition example end to end.

    Parses command-line arguments, builds the learning model, the knowledge
    base and reasoner, loads the datasets, then trains and tests via
    :class:`SimpleBridge`.
    """
    parser = argparse.ArgumentParser(description='MNIST Addition example')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs in each learning loop iteration (default : 1)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='base learning rate (default : 0.001)')
    # Was type=int with a float default (3e-2): "--weight-decay 0.03" would crash.
    parser.add_argument('--weight-decay', type=float, default=3e-2,
                        help='weight decay value (default : 0.03)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='batch size (default : 32)')
    parser.add_argument('--loops', type=int, default=5,
                        help='number of loop iterations (default : 5)')
    parser.add_argument('--segment_size', type=_int_or_float, default=1 / 3,
                        help='segment size (default : 1/3)')
    parser.add_argument('--save_interval', type=int, default=1,
                        help='save interval (default : 1)')
    parser.add_argument('--max-revision', type=_int_or_float, default=-1,
                        help='maximum revision in reasoner (default : -1)')
    # Default was 5, contradicting both this help text and the README (0).
    parser.add_argument('--require-more-revision', type=int, default=0,
                        help='require more revision in reasoner (default : 0)')
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument("--prolog", action="store_true", default=False,
                         help='use PrologKB (default: False)')
    kb_type.add_argument("--ground", action="store_true", default=False,
                         help='use GroundKB (default: False)')

    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    ### Learning Part
    # Build necessary components for BasicNN
    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    # Pass weight_decay through to the optimizer; it was parsed but unused before.
    optimizer = torch.optim.Adam(
        cls.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build BasicNN
    # The function of BasicNN is to wrap NN models into the form of an sklearn estimator
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    # The main function of the ABL model is to serialize data and
    # provide a unified interface for different machine learning models
    model = ABLModel(base_model)

    # Choose the knowledge base implementation requested on the command line.
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()
    reasoner = Reasoner(
        kb,
        dist_func="confidence",
        max_revision=args.max_revision,
        require_more_revision=args.require_more_revision,
    )

    ### Datasets and Evaluation Metrics
    # Get training and testing data
    train_data = get_mnist_add(train=True, get_pseudo_label=True)
    test_data = get_mnist_add(train=False, get_pseudo_label=True)

    # Set up metrics
    metric_list = [
        SymbolMetric(prefix="mnist_add"),
        ReasoningMetric(kb=kb, prefix="mnist_add"),
    ]

    ### Bridge Machine Learning and Logic Reasoning
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()

+ 44
- 14
examples/mnist_add/mnist_add_example.ipynb View File

@@ -10,13 +10,15 @@
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n",
"from examples.mnist_add.datasets import get_mnist_add\n",
"from examples.models.nn import LeNet5"
]
},
@@ -27,19 +29,58 @@
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n",
"print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_data = get_mnist_add(train=False, get_pseudo_label=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set\")\n",
"print(f\"Each of the data example has {len(train_data)} components: X, gt_pseudo_label, and Y.\")\n",
"print(\"For instance, in the First data example in the training set, we have:\")\n",
"print(f\"X ({len(train_data[0][0])} images):\")\n",
"plt.subplot(1,2,1)\n",
"plt.axis('off') \n",
"plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))\n",
"plt.subplot(1,2,2)\n",
"plt.axis('off') \n",
"plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))\n",
"plt.show()\n",
"print(f\"gt_pseudo_label ({len(train_data[1][0])} ground truth pseudo label): {train_data[1][0][0]}, {train_data[1][0][1]}\")\n",
"print(f\"Y (their sum result): {train_data[2][0]}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Machine Learning Part"
"### Learning Part"
]
},
{
@@ -120,17 +161,6 @@
"### Datasets and Evaluation Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_data = get_mnist_add(train=False, get_pseudo_label=True)"
]
},
{
"cell_type": "code",
"execution_count": null,


Loading…
Cancel
Save