"""Handwritten Formula (HWF) example script.

Defines two knowledge bases that evaluate arithmetic formulas made of the
digits 1-9 and the operators ``+ - * /``, builds a ``SymbolNet``-based learning
model, and trains/tests the combination through a ``SimpleBridge``.
"""

import argparse
import os.path as osp

import numpy as np
import torch
from torch import nn

from ablkit.bridge import SimpleBridge
from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
from ablkit.learning import ABLModel, BasicNN
from ablkit.reasoning import GroundKB, KBBase, Reasoner
from ablkit.utils import ABLLogger, print_log

from datasets import get_dataset
from models.nn import SymbolNet

# Symbols allowed at even (operand) and odd (operator) positions of a formula.
# Tuples are used so membership tests keep plain ``==`` comparison semantics.
_DIGITS = ("1", "2", "3", "4", "5", "6", "7", "8", "9")
_OPERATORS = ("+", "-", "*", "/")
_GKB_LEN_LIST = (1, 3, 5, 7)


def _is_valid_formula(formula):
    """Return True if ``formula`` alternates digit, operator, ..., digit.

    A valid formula has odd length, a digit from ``_DIGITS`` at every even
    index and an operator from ``_OPERATORS`` at every odd index. An empty
    formula has even length and is therefore rejected.
    """
    if len(formula) % 2 == 0:
        return False
    return all(
        symbol in (_DIGITS if index % 2 == 0 else _OPERATORS)
        for index, symbol in enumerate(formula)
    )


def _evaluate_formula(formula):
    """Evaluate a formula, returning ``np.inf`` when it is not well formed."""
    if not _is_valid_formula(formula):
        return np.inf
    # NOTE(security): ``eval`` is only reached after ``_is_valid_formula`` has
    # restricted every symbol to the digits 1-9 and the four arithmetic
    # operators. Do not call it on unvalidated input.
    return eval("".join(formula))


class HwfKB(KBBase):
    """Knowledge base for HWF that evaluates a formula of pseudo-labels.

    Args:
        pseudo_label_list: Symbols the learning part may predict. Defaults to
            the digits 1-9 followed by ``+ - * /``.
        max_err: Forwarded to ``KBBase``; the CLI describes it as the max
            tolerance during abductive reasoning.
    """

    def __init__(self, pseudo_label_list=None, max_err=1e-10):
        # Build a fresh list per instance instead of sharing a mutable default.
        if pseudo_label_list is None:
            pseudo_label_list = list(_DIGITS + _OPERATORS)
        super().__init__(pseudo_label_list, max_err)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` is an alternating digit/operator sequence."""
        return _is_valid_formula(formula)

    # Implement the deduction function
    def logic_forward(self, formula):
        """Return the value of ``formula``, or ``np.inf`` if it is invalid."""
        return _evaluate_formula(formula)


class HwfGroundKB(GroundKB):
    """Ground knowledge base variant for HWF with the same deduction rule.

    Args:
        pseudo_label_list: Symbols the learning part may predict. Defaults to
            the digits 1-9 followed by ``+ - * /``.
        GKB_len_list: Forwarded to ``GroundKB``; defaults to ``[1, 3, 5, 7]``.
            Presumably the formula lengths to ground -- confirm against the
            ``GroundKB`` documentation.
        max_err: Forwarded to ``GroundKB``; the CLI describes it as the max
            tolerance during abductive reasoning.
    """

    def __init__(self, pseudo_label_list=None, GKB_len_list=None, max_err=1e-10):
        # Build fresh lists per instance instead of sharing mutable defaults.
        if pseudo_label_list is None:
            pseudo_label_list = list(_DIGITS + _OPERATORS)
        if GKB_len_list is None:
            GKB_len_list = list(_GKB_LEN_LIST)
        super().__init__(pseudo_label_list, GKB_len_list, max_err)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` is an alternating digit/operator sequence."""
        return _is_valid_formula(formula)

    # Implement the deduction function
    def logic_forward(self, formula):
        """Return the value of ``formula``, or ``np.inf`` if it is invalid."""
        return _evaluate_formula(formula)


def main():
    """Parse CLI options, then build, train and test the HWF pipeline."""
    parser = argparse.ArgumentParser(description="Handwritten Formula example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="number of epochs in each learning loop iteration (default : 3)",
    )
    parser.add_argument(
        "--label-smoothing",
        type=float,
        default=0.2,
        help="label smoothing in cross entropy loss (default : 0.2)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=128, help="base model batch size (default : 128)"
    )
    parser.add_argument(
        "--loops", type=int, default=3, help="number of loop iterations (default : 3)"
    )
    parser.add_argument(
        "--segment_size", type=int, default=1000, help="segment size (default : 1000)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    parser.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    parser.add_argument(
        "--max-err",
        type=float,
        default=1e-10,
        help="max tolerance during abductive reasoning (default : 1e-10)",
    )

    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the HWF example.", logger="current")

    # -- Working with Data ------------------------------
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    # -- Building the Learning Part ---------------------
    print_log("Building the Learning Part.", logger="current")

    # Build necessary components for BasicNN
    cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
    loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build BasicNN
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    # -- Building the Reasoning Part --------------------
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base. ``--max-err`` is forwarded so the CLI option takes
    # effect; its default equals the constructors' default.
    if args.ground:
        kb = HwfGroundKB(max_err=args.max_err)
    else:
        kb = HwfKB(max_err=args.max_err)

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    # -- Building Evaluation Metrics --------------------
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

    # -- Bridging Learning and Reasoning ----------------
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        val_data=test_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()