| @@ -17,8 +17,8 @@ class ModelConverter: | |||
| self, | |||
| lambdalearn_model, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer: torch.optim.Optimizer, | |||
| scheduler: Optional[Callable[..., Any]] = None, | |||
| optimizer_dict: dict, | |||
| scheduler_dict: Optional[dict] = None, | |||
| device: Optional[torch.device] = None, | |||
| batch_size: int = 32, | |||
| num_epochs: int = 1, | |||
| @@ -39,11 +39,13 @@ class ModelConverter: | |||
| The LambdaLearn model to be converted. | |||
| loss_fn : torch.nn.Module | |||
| The loss function used for training. | |||
| optimizer : torch.optim.Optimizer | |||
| The optimizer used for training. | |||
| scheduler : Callable[..., Any], optional | |||
| The learning rate scheduler used for training, which will be called | |||
| at the end of each run of the ``fit`` method. It should implement the | |||
| optimizer_dict : dict | |||
| The dict contains the parameters necessary to construct an optimizer used for training. | |||
| The optimizer class is specified by the ``optimizer`` key. | |||
| scheduler_dict : dict, optional | |||
| The dict contains necessary parameters to construct a learning rate scheduler used | |||
| for training, which will be called at the end of each run of the ``fit`` method. | |||
| The scheduler class is specified by the ``scheduler`` key. It should implement the | |||
| ``step`` method, by default None. | |||
| device : torch.device, optional | |||
| The device on which the model will be trained or used for prediction, | |||
| @@ -75,7 +77,7 @@ class ModelConverter: | |||
| The converted ABLModel instance. | |||
| ''' | |||
| if isinstance(lambdalearn_model, DeepModelMixin): | |||
| base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer, scheduler, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn) | |||
| base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer_dict, scheduler_dict, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn) | |||
| return ABLModel(base_model) | |||
| if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")): | |||
| @@ -88,8 +90,8 @@ class ModelConverter: | |||
| self, | |||
| lambdalearn_model: DeepModelMixin, | |||
| loss_fn: torch.nn.Module, | |||
| optimizer: torch.optim.Optimizer, | |||
| scheduler: Optional[Callable[..., Any]] = None, | |||
| optimizer_dict: dict, | |||
| scheduler_dict: Optional[dict] = None, | |||
| device: Optional[torch.device] = None, | |||
| batch_size: int = 32, | |||
| num_epochs: int = 1, | |||
| @@ -110,11 +112,12 @@ class ModelConverter: | |||
| The LambdaLearn model to be converted. | |||
| loss_fn : torch.nn.Module | |||
| The loss function used for training. | |||
| optimizer : torch.optim.Optimizer | |||
| The optimizer used for training. | |||
| scheduler : Callable[..., Any], optional | |||
| The learning rate scheduler used for training, which will be called | |||
| at the end of each run of the ``fit`` method. It should implement the | |||
| optimizer_dict : dict | |||
| The dict contains the parameters necessary to construct an optimizer used for training. | |||
| The optimizer class is specified by the ``optimizer`` key. | |||
| scheduler_dict : dict, optional | |||
| The dict contains necessary parameters to construct a learning rate scheduler used | |||
| for training, which will be called at the end of each run of the ``fit`` method. | |||
| The scheduler class is specified by the ``scheduler`` key. It should implement the | |||
| ``step`` method, by default None. | |||
| device : torch.device, optional | |||
| The device on which the model will be trained or used for prediction, | |||
| @@ -150,6 +153,15 @@ class ModelConverter: | |||
| raise NotImplementedError(f"Expected lambdalearn_model.network to be a torch.nn.Module, but got {type(lambdalearn_model.network)}") | |||
| # Only use the network part and device of the lambdalearn model | |||
| network = copy.deepcopy(lambdalearn_model.network) | |||
| optimizer_class = optimizer_dict["optimizer"] | |||
| optimizer_dict.pop("optimizer") | |||
| optimizer = optimizer_class(network.parameters(), **optimizer_dict) | |||
| if scheduler_dict is not None: | |||
| scheduler_class = scheduler_dict["scheduler"] | |||
| scheduler_dict.pop("scheduler") | |||
| scheduler = scheduler_class(optimizer, **scheduler_dict) | |||
| else: | |||
| scheduler = None | |||
| device = lambdalearn_model.device if device is None else device | |||
| base_model = BasicNN(model=network, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, device=device, batch_size=batch_size, num_epochs=num_epochs, stop_loss=stop_loss, num_workers=num_workers, save_interval=save_interval, save_dir=save_dir, train_transform=train_transform, test_transform=test_transform, collate_fn=collate_fn) | |||
| return base_model | |||
| @@ -0,0 +1,148 @@ | |||
| import argparse | |||
| import os.path as osp | |||
| import torch | |||
| from torch import nn | |||
| from torch.optim import RMSprop, lr_scheduler | |||
| from lambdaLearn.Algorithm.AbductiveLearning.bridge import SimpleBridge | |||
| from lambdaLearn.Algorithm.AbductiveLearning.data.evaluation import ReasoningMetric, SymbolAccuracy | |||
| from lambdaLearn.Algorithm.AbductiveLearning.learning import ABLModel | |||
| from lambdaLearn.Algorithm.AbductiveLearning.learning.model_converter import ModelConverter | |||
| from lambdaLearn.Algorithm.AbductiveLearning.reasoning import GroundKB, KBBase, PrologKB, Reasoner | |||
| from lambdaLearn.Algorithm.AbductiveLearning.utils import ABLLogger, print_log | |||
| from lambdaLearn.Algorithm.Classification.FixMatch import FixMatch | |||
| from datasets import get_dataset | |||
| from models.nn import LeNet5 | |||
class AddKB(KBBase):
    """Knowledge base for the MNIST Addition example.

    The reasoning result of a sequence of pseudo-labels is their sum.

    Parameters
    ----------
    pseudo_label_list : list, optional
        Pseudo-labels forwarded to ``KBBase``. When omitted (``None``), the
        digits 0-9 are used.
    """

    def __init__(self, pseudo_label_list=None):
        # Build the default inside the call instead of in the signature, so
        # that instances never share (and cannot corrupt) one list object.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
class AddGroundKB(GroundKB):
    """Ground knowledge base for the MNIST Addition example.

    The reasoning result of a sequence of pseudo-labels is their sum.

    Parameters
    ----------
    pseudo_label_list : list, optional
        Pseudo-labels forwarded to ``GroundKB``. When omitted (``None``), the
        digits 0-9 are used.
    GKB_len_list : list, optional
        Forwarded to ``GroundKB`` as its second positional argument. When
        omitted (``None``), ``[2]`` is used.
    """

    def __init__(self, pseudo_label_list=None, GKB_len_list=None):
        # Build the defaults inside the call instead of in the signature, so
        # that instances never share (and cannot corrupt) the same lists.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if GKB_len_list is None:
            GKB_len_list = [2]
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
def parse_segment_size(value):
    """Parse the ``--segment_size`` command-line value.

    The option's default is the float ``0.01``, so fractional values must be
    accepted on the command line as well. Integer-looking text is kept as an
    ``int`` so existing integer invocations keep their type.

    Parameters
    ----------
    value : str
        Raw text supplied on the command line.

    Returns
    -------
    int or float
        ``int`` when ``value`` parses as an integer, otherwise ``float``.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is neither an integer nor a float.
    """
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid segment size: {value!r}") from None


def main():
    """Run the MNIST Addition abductive-learning example.

    Parses the command-line options, converts the network of a ``FixMatch``
    model into a base model wrapped by ``ABLModel``, builds the knowledge
    base, reasoner and metrics, and finally trains and tests everything
    through ``SimpleBridge``.
    """
    parser = argparse.ArgumentParser(description="MNIST Addition example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default : 1)",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)"
    )
    parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
    )
    parser.add_argument(
        "--segment_size",
        # ``type=int`` would reject fractional values such as the default 0.01.
        type=parse_segment_size,
        default=0.01,
        help="segment size (default : 0.01)",
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument(
        "--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
    )
    kb_type.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )

    args = parser.parse_args()

    # Announce the run through the current logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    ### Working with Data
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    ### Building the Learning Part
    print_log("Building the Learning Part.", logger="current")

    # Honour --no-cuda and fall back to the CPU when CUDA is unavailable,
    # instead of hard-coding "cuda:0".
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Build necessary components for BasicNN
    lambdalearn_model = FixMatch(
        network=LeNet5(),
        threshold=0.95,
        lambda_u=1.0,
        mu=7,
        T=0.5,
        epoch=1,
        num_it_epoch=2**20,
        num_it_total=2**20,
        device=str(device),
    )
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
    # The optimizer/scheduler hyper-parameters come from the command line;
    # their defaults equal the values that used to be hard-coded here.
    optimizer_dict = dict(optimizer=RMSprop, lr=args.lr, alpha=args.alpha)
    scheduler_dict = dict(
        scheduler=lr_scheduler.OneCycleLR, max_lr=args.lr, pct_start=0.15, total_steps=200
    )

    converter = ModelConverter()
    base_model = converter.convert_lambdalearn_to_basicnn(
        lambdalearn_model,
        loss_fn=loss_fn,
        optimizer_dict=optimizer_dict,
        scheduler_dict=scheduler_dict,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    ### Building the Reasoning Part
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    ### Building Evaluation Metrics
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

    ### Bridge Learning and Reasoning
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()