Browse Source

[ENH] modify ModelConverter and add relative mnist example

pull/1/head
Gao Enhao 2 years ago
parent
commit
c0e4ad8726
2 changed files with 175 additions and 15 deletions
  1. +27
    -15
      abl/learning/model_converter.py
  2. +148
    -0
      examples/mnist_add/main_with_model_converter.py

+ 27
- 15
abl/learning/model_converter.py View File

@@ -17,8 +17,8 @@ class ModelConverter:
self,
lambdalearn_model,
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Callable[..., Any]] = None,
optimizer_dict: dict,
scheduler_dict: Optional[dict] = None,
device: Optional[torch.device] = None,
batch_size: int = 32,
num_epochs: int = 1,
@@ -39,11 +39,13 @@ class ModelConverter:
The LambdaLearn model to be converted.
loss_fn : torch.nn.Module
The loss function used for training.
optimizer : torch.optim.Optimizer
The optimizer used for training.
scheduler : Callable[..., Any], optional
The learning rate scheduler used for training, which will be called
at the end of each run of the ``fit`` method. It should implement the
optimizer_dict : dict
The dict contains necessary parameters to construct an optimizer used for training.
The optimizer class is specified by the ``optimizer`` key.
scheduler_dict : dict, optional
The dict contains necessary parameters to construct a learning rate scheduler used
for training, which will be called at the end of each run of the ``fit`` method.
The scheduler class is specified by the ``scheduler`` key. It should implement the
``step`` method, by default None.
device : torch.device, optional
The device on which the model will be trained or used for prediction,
@@ -75,7 +77,7 @@ class ModelConverter:
The converted ABLModel instance.
'''
if isinstance(lambdalearn_model, DeepModelMixin):
base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer, scheduler, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn)
base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer_dict, scheduler_dict, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn)
return ABLModel(base_model)
if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")):
@@ -88,8 +90,8 @@ class ModelConverter:
self,
lambdalearn_model: DeepModelMixin,
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Callable[..., Any]] = None,
optimizer_dict: dict,
scheduler_dict: Optional[dict] = None,
device: Optional[torch.device] = None,
batch_size: int = 32,
num_epochs: int = 1,
@@ -110,11 +112,12 @@ class ModelConverter:
The LambdaLearn model to be converted.
loss_fn : torch.nn.Module
The loss function used for training.
optimizer : torch.optim.Optimizer
The optimizer used for training.
scheduler : Callable[..., Any], optional
The learning rate scheduler used for training, which will be called
at the end of each run of the ``fit`` method. It should implement the
optimizer_dict : dict
The dict contains necessary parameters to construct an optimizer used for training.
The optimizer class is specified by the ``optimizer`` key.
scheduler_dict : dict, optional
The dict contains necessary parameters to construct a learning rate scheduler used
for training, which will be called at the end of each run of the ``fit`` method.
The scheduler class is specified by the ``scheduler`` key. It should implement the
``step`` method, by default None.
device : torch.device, optional
The device on which the model will be trained or used for prediction,
@@ -150,6 +153,15 @@ class ModelConverter:
raise NotImplementedError(f"Expected lambdalearn_model.network to be a torch.nn.Module, but got {type(lambdalearn_model.network)}")
# Only use the network part and device of the lambdalearn model
network = copy.deepcopy(lambdalearn_model.network)
optimizer_class = optimizer_dict["optimizer"]
optimizer_dict.pop("optimizer")
optimizer = optimizer_class(network.parameters(), **optimizer_dict)
if scheduler_dict is not None:
scheduler_class = scheduler_dict["scheduler"]
scheduler_dict.pop("scheduler")
scheduler = scheduler_class(optimizer, **scheduler_dict)
else:
scheduler = None
device = lambdalearn_model.device if device is None else device
base_model = BasicNN(model=network, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, device=device, batch_size=batch_size, num_epochs=num_epochs, stop_loss=stop_loss, num_workers=num_workers, save_interval=save_interval, save_dir=save_dir, train_transform=train_transform, test_transform=test_transform, collate_fn=collate_fn)
return base_model


+ 148
- 0
examples/mnist_add/main_with_model_converter.py View File

@@ -0,0 +1,148 @@
import argparse
import os.path as osp

import torch
from torch import nn
from torch.optim import RMSprop, lr_scheduler

from lambdaLearn.Algorithm.AbductiveLearning.bridge import SimpleBridge
from lambdaLearn.Algorithm.AbductiveLearning.data.evaluation import ReasoningMetric, SymbolAccuracy
from lambdaLearn.Algorithm.AbductiveLearning.learning import ABLModel
from lambdaLearn.Algorithm.AbductiveLearning.learning.model_converter import ModelConverter
from lambdaLearn.Algorithm.AbductiveLearning.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from lambdaLearn.Algorithm.AbductiveLearning.utils import ABLLogger, print_log
from lambdaLearn.Algorithm.Classification.FixMatch import FixMatch

from datasets import get_dataset
from models.nn import LeNet5


class AddKB(KBBase):
    """Knowledge base for the MNIST Addition task.

    The reasoning result of a sequence of digit pseudo-labels is their sum.

    Parameters
    ----------
    pseudo_label_list : list, optional
        Pseudo-labels forwarded to ``KBBase``. Defaults to the digits 0-9.
    """

    def __init__(self, pseudo_label_list=None):
        # Build the default per instance: a list literal in the signature would
        # be a single object shared (and mutable) across every AddKB created.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        """Return the sum of the given numbers."""
        return sum(nums)


class AddGroundKB(GroundKB):
    """Ground knowledge base for the MNIST Addition task.

    The reasoning result of a sequence of digit pseudo-labels is their sum.

    Parameters
    ----------
    pseudo_label_list : list, optional
        Pseudo-labels forwarded to ``GroundKB``. Defaults to the digits 0-9.
    GKB_len_list : list, optional
        Forwarded to ``GroundKB`` as its second argument. Defaults to ``[2]``.
    """

    def __init__(self, pseudo_label_list=None, GKB_len_list=None):
        # Build the defaults per instance: list literals in the signature would
        # be single objects shared (and mutable) across every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2]
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        """Return the sum of the given numbers."""
        return sum(nums)


def _int_or_float(text):
    """argparse ``type`` callable: parse ``text`` as an int if possible, else as a float."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def main():
    """Run the MNIST Addition example with a base model built by ``ModelConverter``.

    Parses the command-line options, loads the data, converts a FixMatch model
    wrapping LeNet5 into a base model, wraps it in an ``ABLModel``, builds the
    knowledge base / reasoner / metrics, and then trains and tests through a
    ``SimpleBridge``.
    """
    parser = argparse.ArgumentParser(description="MNIST Addition example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default : 1)",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)"
    )
    parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
    )
    # The default (0.01) is fractional, so ``type=int`` would reject any
    # fractional value given on the command line; accept ints and floats.
    parser.add_argument(
        "--segment_size", type=_int_or_float, default=0.01, help="segment size (default : 0.01)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument(
        "--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
    )
    kb_type.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )

    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    ### Working with Data
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    ### Building the Learning Part
    print_log("Building the Learning Part.", logger="current")
    # Honour --no-cuda and fall back to CPU when CUDA is unavailable.
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Build the LambdaLearn model whose network will be converted.
    ssl_model = FixMatch(
        network=LeNet5(),
        threshold=0.95,
        lambda_u=1.0,
        mu=7,
        T=0.5,
        epoch=1,
        num_it_epoch=2**20,
        num_it_total=2**20,
        device=str(device),
    )

    # Build necessary components for the converted base model
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
    optimizer_dict = dict(optimizer=RMSprop, lr=args.lr, alpha=args.alpha)
    # NOTE(review): total_steps=200 looks tied to the default --loops and
    # --segment_size; confirm it is large enough when those are changed.
    scheduler_dict = dict(
        scheduler=lr_scheduler.OneCycleLR, max_lr=args.lr, pct_start=0.15, total_steps=200
    )

    converter = ModelConverter()
    base_model = converter.convert_lambdalearn_to_basicnn(
        ssl_model,
        loss_fn=loss_fn,
        optimizer_dict=optimizer_dict,
        scheduler_dict=scheduler_dict,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    ### Building the Reasoning Part
    print_log("Building the Reasoning Part.", logger="current")
    # Build knowledge base
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    ### Building Evaluation Metrics
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

    ### Bridge Learning and Reasoning
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()

Loading…
Cancel
Save