@@ -24,6 +24,9 @@ machine learning model.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import RMSprop, lr_scheduler
from examples.mnist_add.datasets import get_dataset
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
@@ -139,15 +142,17 @@ model with an sklearn-style interface.
.. code:: ipython3
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim. RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1 )
optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)
base_model = BasicNN(
cls,
loss_fn,
optimizer,
device,
scheduler=scheduler,
device=device,
batch_size=32,
num_epochs=1,
)
@@ -328,8 +333,8 @@ methods of ``SimpleBridge``.
print_log("Abductive Learning on the MNIST Addition example.", logger="current")
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
bridge.train(train_data, loops=5, segment_size=1/3 , save_interval=1, save_dir=weights_dir)
bridge.train(train_data, loops=1, segment_size=0.01 , save_interval=1, save_dir=weights_dir)
bridge.test(test_data)
Out:
@@ -337,56 +342,36 @@ Out:
:class: code-out
abl - INFO - Abductive Learning on the MNIST Addition example.
abl - INFO - loop(train) [1/5] segment(train) [1/3]
abl - INFO - model loss: 1.49104
abl - INFO - loop(train) [1/5] segment(train) [2/3]
abl - INFO - model loss: 1.24945
abl - INFO - loop(train) [1/5] segment(train) [3/3]
abl - INFO - model loss: 0.87861
abl - INFO - Evaluation start: loop(val) [1]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672
abl - INFO - loop(train) [1/1] segment(train) [1/100]
abl - INFO - model loss: 2.23587
abl - INFO - loop(train) [1/1] segment(train) [2/100]
abl - INFO - model loss: 2.23756
abl - INFO - loop(train) [1/1] segment(train) [3/100]
abl - INFO - model loss: 2.04475
abl - INFO - loop(train) [1/1] segment(train) [4/100]
abl - INFO - model loss: 2.01035
abl - INFO - loop(train) [1/1] segment(train) [5/100]
abl - INFO - model loss: 1.97584
abl - INFO - loop(train) [1/1] segment(train) [6/100]
abl - INFO - model loss: 1.91570
abl - INFO - loop(train) [1/1] segment(train) [7/100]
abl - INFO - model loss: 1.90268
abl - INFO - loop(train) [1/1] segment(train) [8/100]
abl - INFO - model loss: 1.77436
abl - INFO - loop(train) [1/1] segment(train) [9/100]
abl - INFO - model loss: 1.73454
abl - INFO - loop(train) [1/1] segment(train) [10/100]
abl - INFO - model loss: 1.62495
abl - INFO - loop(train) [1/1] segment(train) [11/100]
abl - INFO - model loss: 1.58456
abl - INFO - loop(train) [1/1] segment(train) [12/100]
abl - INFO - model loss: 1.62575
...
abl - INFO - Eval start: loop(val) [1]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973
abl - INFO - Saving model: loop(save) [1]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
abl - INFO - loop(train) [2/5] segment(train) [1/3]
abl - INFO - model loss: 0.31148
abl - INFO - loop(train) [2/5] segment(train) [2/3]
abl - INFO - model loss: 0.09520
abl - INFO - loop(train) [2/5] segment(train) [3/3]
abl - INFO - model loss: 0.07402
abl - INFO - Evaluation start: loop(val) [2]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964
abl - INFO - Saving model: loop(save) [2]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
abl - INFO - loop(train) [3/5] segment(train) [1/3]
abl - INFO - model loss: 0.06027
abl - INFO - loop(train) [3/5] segment(train) [2/3]
abl - INFO - model loss: 0.05341
abl - INFO - loop(train) [3/5] segment(train) [3/3]
abl - INFO - model loss: 0.04915
abl - INFO - Evaluation start: loop(val) [3]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
abl - INFO - Saving model: loop(save) [3]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth
abl - INFO - loop(train) [4/5] segment(train) [1/3]
abl - INFO - model loss: 0.04413
abl - INFO - loop(train) [4/5] segment(train) [2/3]
abl - INFO - model loss: 0.04181
abl - INFO - loop(train) [4/5] segment(train) [3/3]
abl - INFO - model loss: 0.04127
abl - INFO - Evaluation start: loop(val) [4]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980
abl - INFO - Saving model: loop(save) [4]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth
abl - INFO - loop(train) [5/5] segment(train) [1/3]
abl - INFO - model loss: 0.03544
abl - INFO - loop(train) [5/5] segment(train) [2/3]
abl - INFO - model loss: 0.03092
abl - INFO - loop(train) [5/5] segment(train) [3/3]
abl - INFO - model loss: 0.03663
abl - INFO - Evaluation start: loop(val) [5]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982
abl - INFO - Saving model: loop(save) [5]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974
abl - INFO - Checkpoints will be saved to results/20231222_22_25_07/weights/model_checkpoint_loop_1.pth
abl - INFO - Test start:
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967
More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.