From 7ca19ecda28f5b8e69aab7acad0ec8b8c1ddba64 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 22 Dec 2023 22:46:47 +0800 Subject: [PATCH] [DOC] modify mnist relative doc --- abl/bridge/simple_bridge.py | 3 +- abl/learning/basic_nn.py | 19 ++++---- abl/reasoning/reasoner.py | 2 +- docs/Examples/MNISTAdd.rst | 97 ++++++++++++++++--------------------- docs/Intro/Quick-Start.rst | 11 +++-- 5 files changed, 59 insertions(+), 73 deletions(-) diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index d0d39a1..e24706e 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -282,7 +282,7 @@ class SimpleBridge(BaseBridge): self.model.train(sub_data_examples) if (loop + 1) % eval_interval == 0 or loop == loops - 1: - print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") + print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") self._valid(val_data_examples) if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): @@ -349,5 +349,6 @@ class SimpleBridge(BaseBridge): ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. """ + print_log("Test start:", logger="current") test_data_examples = self.data_preprocess("test", test_data) self._valid(test_data_examples) diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 721576a..15e58b6 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -24,9 +24,10 @@ class BasicNN: The loss function used for training. optimizer : torch.optim.Optimizer The optimizer used for training. - scheduler : torch.optim.lr_scheduler.LRScheduler + scheduler : Callable[..., Any], optional The learning rate scheduler used for training, which will be called - at the end of each run of the ``fit`` method, by default None. + at the end of each run of the ``fit`` method. It should implement the + ``step`` method, by default None. device : torch.device, optional The device on which the model will be trained or used for prediction, by default torch.device("cpu"). @@ -34,13 +35,13 @@ class BasicNN: The batch size used for training, by default 32. num_epochs : int, optional The number of epochs used for training, by default 1. - stop_loss : Optional[float], optional + stop_loss : float, optional The loss value at which to stop training, by default 0.0001. num_workers : int The number of workers used for loading data, by default 0. - save_interval : Optional[int], optional + save_interval : int, optional The model will be saved every ``save_interval`` epochs during training, by default None. - save_dir : Optional[str], optional + save_dir : str, optional The directory in which to save the model during training, by default None. train_transform : Callable[..., Any], optional A function/transform that takes an object and returns a transformed version used @@ -57,7 +58,7 @@ class BasicNN: model: torch.nn.Module, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + scheduler: Optional[Callable[..., Any]] = None, device: torch.device = torch.device("cpu"), batch_size: int = 32, num_epochs: int = 1, @@ -75,10 +76,8 @@ class BasicNN: raise TypeError("loss_fn must be an instance of torch.nn.Module") if not isinstance(optimizer, torch.optim.Optimizer): raise TypeError("optimizer must be an instance of torch.optim.Optimizer") - if scheduler is not None and not isinstance( - scheduler, torch.optim.lr_scheduler.LRScheduler - ): - raise TypeError("scheduler must be an instance of torch.optim.lr_scheduler.LRScheduler") + if scheduler is not None and not hasattr(scheduler, "step"): + raise NotImplementedError("scheduler should implement the ``step`` method") if not isinstance(device, torch.device): raise TypeError("device must be an instance of torch.device") if not isinstance(batch_size, int): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index a7f2c4d..7209c9c 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -32,7 +32,7 @@ class Reasoner: in this cost list should be a numerical value representing the cost for each candidate, and the list should have the same length as candidates. Defaults to 'confidence'. - idx_to_label : Optional[dict], optional + idx_to_label : dict, optional A mapping from index in the base model to label. If not provided, a default order-based index to label mapping is created. Defaults to None. max_revision : Union[int, float], optional diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index 3f69b38..f8175ff 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -24,6 +24,9 @@ machine learning model. import torch import torch.nn as nn import matplotlib.pyplot as plt + + from torch.optim import RMSprop, lr_scheduler + from examples.mnist_add.datasets import get_dataset from examples.models.nn import LeNet5 from abl.learning import ABLModel, BasicNN @@ -139,15 +142,17 @@ model with an sklearn-style interface. .. code:: ipython3 cls = LeNet5(num_classes=10) - loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9) + loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) + optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - + scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100) + base_model = BasicNN( cls, loss_fn, optimizer, - device, + scheduler=scheduler, + device=device, batch_size=32, num_epochs=1, ) @@ -328,8 +333,8 @@ methods of ``SimpleBridge``. print_log("Abductive Learning on the MNIST Addition example.", logger="current") log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - - bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir) + + bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir) bridge.test(test_data) Out: @@ -337,56 +342,36 @@ Out: :class: code-out abl - INFO - Abductive Learning on the MNIST Addition example. - abl - INFO - loop(train) [1/5] segment(train) [1/3] - abl - INFO - model loss: 1.49104 - abl - INFO - loop(train) [1/5] segment(train) [2/3] - abl - INFO - model loss: 1.24945 - abl - INFO - loop(train) [1/5] segment(train) [3/3] - abl - INFO - model loss: 0.87861 - abl - INFO - Evaluation start: loop(val) [1] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672 + abl - INFO - loop(train) [1/1] segment(train) [1/100] + abl - INFO - model loss: 2.23587 + abl - INFO - loop(train) [1/1] segment(train) [2/100] + abl - INFO - model loss: 2.23756 + abl - INFO - loop(train) [1/1] segment(train) [3/100] + abl - INFO - model loss: 2.04475 + abl - INFO - loop(train) [1/1] segment(train) [4/100] + abl - INFO - model loss: 2.01035 + abl - INFO - loop(train) [1/1] segment(train) [5/100] + abl - INFO - model loss: 1.97584 + abl - INFO - loop(train) [1/1] segment(train) [6/100] + abl - INFO - model loss: 1.91570 + abl - INFO - loop(train) [1/1] segment(train) [7/100] + abl - INFO - model loss: 1.90268 + abl - INFO - loop(train) [1/1] segment(train) [8/100] + abl - INFO - model loss: 1.77436 + abl - INFO - loop(train) [1/1] segment(train) [9/100] + abl - INFO - model loss: 1.73454 + abl - INFO - loop(train) [1/1] segment(train) [10/100] + abl - INFO - model loss: 1.62495 + abl - INFO - loop(train) [1/1] segment(train) [11/100] + abl - INFO - model loss: 1.58456 + abl - INFO - loop(train) [1/1] segment(train) [12/100] + abl - INFO - model loss: 1.62575 + ... + abl - INFO - Eval start: loop(val) [1] + abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973 abl - INFO - Saving model: loop(save) [1] - abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth - abl - INFO - loop(train) [2/5] segment(train) [1/3] - abl - INFO - model loss: 0.31148 - abl - INFO - loop(train) [2/5] segment(train) [2/3] - abl - INFO - model loss: 0.09520 - abl - INFO - loop(train) [2/5] segment(train) [3/3] - abl - INFO - model loss: 0.07402 - abl - INFO - Evaluation start: loop(val) [2] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964 - abl - INFO - Saving model: loop(save) [2] - abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth - abl - INFO - loop(train) [3/5] segment(train) [1/3] - abl - INFO - model loss: 0.06027 - abl - INFO - loop(train) [3/5] segment(train) [2/3] - abl - INFO - model loss: 0.05341 - abl - INFO - loop(train) [3/5] segment(train) [3/3] - abl - INFO - model loss: 0.04915 - abl - INFO - Evaluation start: loop(val) [3] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975 - abl - INFO - Saving model: loop(save) [3] - abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth - abl - INFO - loop(train) [4/5] segment(train) [1/3] - abl - INFO - model loss: 0.04413 - abl - INFO - loop(train) [4/5] segment(train) [2/3] - abl - INFO - model loss: 0.04181 - abl - INFO - loop(train) [4/5] segment(train) [3/3] - abl - INFO - model loss: 0.04127 - abl - INFO - Evaluation start: loop(val) [4] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980 - abl - INFO - Saving model: loop(save) [4] - abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth - abl - INFO - loop(train) [5/5] segment(train) [1/3] - abl - INFO - model loss: 0.03544 - abl - INFO - loop(train) [5/5] segment(train) [2/3] - abl - INFO - model loss: 0.03092 - abl - INFO - loop(train) [5/5] segment(train) [3/3] - abl - INFO - model loss: 0.03663 - abl - INFO - Evaluation start: loop(val) [5] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982 - abl - INFO - Saving model: loop(save) [5] - abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974 + abl - INFO - Checkpoints will be saved to results/20231222_22_25_07/weights/model_checkpoint_loop_1.pth + abl - INFO - Test start: + abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967 More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``. \ No newline at end of file diff --git a/docs/Intro/Quick-Start.rst b/docs/Intro/Quick-Start.rst index 215f2e8..12d8ea8 100644 --- a/docs/Intro/Quick-Start.rst +++ b/docs/Intro/Quick-Start.rst @@ -52,17 +52,18 @@ To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to from abl.learning import BasicNN loss_fn = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) + optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, alpha=0.9) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - base_model = BasicNN(cls, loss_fn, optimizer, device) + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100) + base_model = BasicNN(cls, loss_fn, optimizer, scheduler=scheduler, device=device) However, Base model built above are trained to make predictions on instance-level data (e.g., a single image), which is not suitable enough for our task. Therefore, we then wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data, (e.g., images that comprise the equation). .. code:: python - from abl.learning import ABLModel + from abl.learning import ABLModel - model = ABLModel(base_model) + model = ABLModel(base_model) Read more about `building the learning part `_. @@ -132,7 +133,7 @@ Finally, we proceed with training and testing. .. code:: python - bridge.train(train_data, loops=5, segment_size=1/3) + bridge.train(train_data, loops=1, segment_size=0.01) bridge.test(test_data) Read more about `bridging machine learning and reasoning `_.