Browse Source

[DOC] modify mnist relative doc

pull/1/head
Gao Enhao 2 years ago
parent
commit
7ca19ecda2
5 changed files with 59 additions and 73 deletions
  1. +2
    -1
      abl/bridge/simple_bridge.py
  2. +9
    -10
      abl/learning/basic_nn.py
  3. +1
    -1
      abl/reasoning/reasoner.py
  4. +41
    -56
      docs/Examples/MNISTAdd.rst
  5. +6
    -5
      docs/Intro/Quick-Start.rst

+ 2
- 1
abl/bridge/simple_bridge.py View File

@@ -282,7 +282,7 @@ class SimpleBridge(BaseBridge):
self.model.train(sub_data_examples)

if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
self._valid(val_data_examples)

if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
@@ -349,5 +349,6 @@ class SimpleBridge(BaseBridge):
``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or
not, which depends on the evaluation metrics in ``self.metric_list``.
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)
self._valid(test_data_examples)

+ 9
- 10
abl/learning/basic_nn.py View File

@@ -24,9 +24,10 @@ class BasicNN:
The loss function used for training.
optimizer : torch.optim.Optimizer
The optimizer used for training.
scheduler : torch.optim.lr_scheduler.LRScheduler
scheduler : Callable[..., Any], optional
The learning rate scheduler used for training, which will be called
at the end of each run of the ``fit`` method, by default None.
at the end of each run of the ``fit`` method. It should implement the
``step`` method. By default None.
device : torch.device, optional
The device on which the model will be trained or used for prediction,
by default torch.device("cpu").
@@ -34,13 +35,13 @@ class BasicNN:
The batch size used for training, by default 32.
num_epochs : int, optional
The number of epochs used for training, by default 1.
stop_loss : Optional[float], optional
stop_loss : float, optional
The loss value at which to stop training, by default 0.0001.
num_workers : int
The number of workers used for loading data, by default 0.
save_interval : Optional[int], optional
save_interval : int, optional
The model will be saved every ``save_interval`` epochs during training, by default None.
save_dir : Optional[str], optional
save_dir : str, optional
The directory in which to save the model during training, by default None.
train_transform : Callable[..., Any], optional
A function/transform that takes an object and returns a transformed version used
@@ -57,7 +58,7 @@ class BasicNN:
model: torch.nn.Module,
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
scheduler: Optional[Callable[..., Any]] = None,
device: torch.device = torch.device("cpu"),
batch_size: int = 32,
num_epochs: int = 1,
@@ -75,10 +76,8 @@ class BasicNN:
raise TypeError("loss_fn must be an instance of torch.nn.Module")
if not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError("optimizer must be an instance of torch.optim.Optimizer")
if scheduler is not None and not isinstance(
scheduler, torch.optim.lr_scheduler.LRScheduler
):
raise TypeError("scheduler must be an instance of torch.optim.lr_scheduler.LRScheduler")
if scheduler is not None and not hasattr(scheduler, "step"):
raise NotImplementedError("scheduler should implement the ``step`` method")
if not isinstance(device, torch.device):
raise TypeError("device must be an instance of torch.device")
if not isinstance(batch_size, int):


+ 1
- 1
abl/reasoning/reasoner.py View File

@@ -32,7 +32,7 @@ class Reasoner:
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
idx_to_label : Optional[dict], optional
idx_to_label : dict, optional
A mapping from index in the base model to label. If not provided, a default
order-based index to label mapping is created. Defaults to None.
max_revision : Union[int, float], optional


+ 41
- 56
docs/Examples/MNISTAdd.rst View File

@@ -24,6 +24,9 @@ machine learning model.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.optim import RMSprop, lr_scheduler

from examples.mnist_add.datasets import get_dataset
from examples.models.nn import LeNet5
from abl.learning import ABLModel, BasicNN
@@ -139,15 +142,17 @@ model with an sklearn-style interface.
.. code:: ipython3

cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)

base_model = BasicNN(
cls,
loss_fn,
optimizer,
device,
scheduler=scheduler,
device=device,
batch_size=32,
num_epochs=1,
)
@@ -328,8 +333,8 @@ methods of ``SimpleBridge``.
print_log("Abductive Learning on the MNIST Addition example.", logger="current")
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)
bridge.test(test_data)

Out:
@@ -337,56 +342,36 @@ Out:
:class: code-out

abl - INFO - Abductive Learning on the MNIST Addition example.
abl - INFO - loop(train) [1/5] segment(train) [1/3]
abl - INFO - model loss: 1.49104
abl - INFO - loop(train) [1/5] segment(train) [2/3]
abl - INFO - model loss: 1.24945
abl - INFO - loop(train) [1/5] segment(train) [3/3]
abl - INFO - model loss: 0.87861
abl - INFO - Evaluation start: loop(val) [1]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672
abl - INFO - loop(train) [1/1] segment(train) [1/100]
abl - INFO - model loss: 2.23587
abl - INFO - loop(train) [1/1] segment(train) [2/100]
abl - INFO - model loss: 2.23756
abl - INFO - loop(train) [1/1] segment(train) [3/100]
abl - INFO - model loss: 2.04475
abl - INFO - loop(train) [1/1] segment(train) [4/100]
abl - INFO - model loss: 2.01035
abl - INFO - loop(train) [1/1] segment(train) [5/100]
abl - INFO - model loss: 1.97584
abl - INFO - loop(train) [1/1] segment(train) [6/100]
abl - INFO - model loss: 1.91570
abl - INFO - loop(train) [1/1] segment(train) [7/100]
abl - INFO - model loss: 1.90268
abl - INFO - loop(train) [1/1] segment(train) [8/100]
abl - INFO - model loss: 1.77436
abl - INFO - loop(train) [1/1] segment(train) [9/100]
abl - INFO - model loss: 1.73454
abl - INFO - loop(train) [1/1] segment(train) [10/100]
abl - INFO - model loss: 1.62495
abl - INFO - loop(train) [1/1] segment(train) [11/100]
abl - INFO - model loss: 1.58456
abl - INFO - loop(train) [1/1] segment(train) [12/100]
abl - INFO - model loss: 1.62575
...
abl - INFO - Eval start: loop(val) [1]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973
abl - INFO - Saving model: loop(save) [1]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
abl - INFO - loop(train) [2/5] segment(train) [1/3]
abl - INFO - model loss: 0.31148
abl - INFO - loop(train) [2/5] segment(train) [2/3]
abl - INFO - model loss: 0.09520
abl - INFO - loop(train) [2/5] segment(train) [3/3]
abl - INFO - model loss: 0.07402
abl - INFO - Evaluation start: loop(val) [2]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964
abl - INFO - Saving model: loop(save) [2]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
abl - INFO - loop(train) [3/5] segment(train) [1/3]
abl - INFO - model loss: 0.06027
abl - INFO - loop(train) [3/5] segment(train) [2/3]
abl - INFO - model loss: 0.05341
abl - INFO - loop(train) [3/5] segment(train) [3/3]
abl - INFO - model loss: 0.04915
abl - INFO - Evaluation start: loop(val) [3]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
abl - INFO - Saving model: loop(save) [3]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth
abl - INFO - loop(train) [4/5] segment(train) [1/3]
abl - INFO - model loss: 0.04413
abl - INFO - loop(train) [4/5] segment(train) [2/3]
abl - INFO - model loss: 0.04181
abl - INFO - loop(train) [4/5] segment(train) [3/3]
abl - INFO - model loss: 0.04127
abl - INFO - Evaluation start: loop(val) [4]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980
abl - INFO - Saving model: loop(save) [4]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth
abl - INFO - loop(train) [5/5] segment(train) [1/3]
abl - INFO - model loss: 0.03544
abl - INFO - loop(train) [5/5] segment(train) [2/3]
abl - INFO - model loss: 0.03092
abl - INFO - loop(train) [5/5] segment(train) [3/3]
abl - INFO - model loss: 0.03663
abl - INFO - Evaluation start: loop(val) [5]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982
abl - INFO - Saving model: loop(save) [5]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974
abl - INFO - Checkpoints will be saved to results/20231222_22_25_07/weights/model_checkpoint_loop_1.pth
abl - INFO - Test start:
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967

More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.

+ 6
- 5
docs/Intro/Quick-Start.rst View File

@@ -52,17 +52,18 @@ To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to
from abl.learning import BasicNN

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(cls, loss_fn, optimizer, device)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)
base_model = BasicNN(cls, loss_fn, optimizer, scheduler=scheduler, device=device)

However, the base model built above is trained to make predictions on instance-level data (e.g., a single image), which is not sufficient for our task. Therefore, we wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, enabling the learning part to train, test, and predict on example-level data (e.g., images that comprise an equation).

.. code:: python

from abl.learning import ABLModel
from abl.learning import ABLModel

model = ABLModel(base_model)
model = ABLModel(base_model)

Read more about `building the learning part <Learning.html>`_.

@@ -132,7 +133,7 @@ Finally, we proceed with training and testing.

.. code:: python

bridge.train(train_data, loops=5, segment_size=1/3)
bridge.train(train_data, loops=1, segment_size=0.01)
bridge.test(test_data)

Read more about `bridging machine learning and reasoning <Bridge.html>`_.

Loading…
Cancel
Save