From 95c52ec966fc0adefcf356aafb9501b3310e21f0 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Sat, 30 Dec 2023 08:50:08 +0800 Subject: [PATCH] [ENH] add log in examples --- examples/hed/main.py | 14 +++++++++++--- examples/hwf/main.py | 14 +++++++++++--- examples/mnist_add/main.py | 13 ++++++++++--- examples/zoo/main.py | 21 ++++++++++++++++----- 4 files changed, 48 insertions(+), 14 deletions(-) diff --git a/examples/hed/main.py b/examples/hed/main.py index f4e7564..984ff5c 100644 --- a/examples/hed/main.py +++ b/examples/hed/main.py @@ -46,13 +46,20 @@ def main(): ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the HED example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + total_train_data = get_dataset(train=True) train_data, val_data = split_equation(total_train_data, 3, 1) test_data = get_dataset(train=False) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() @@ -75,6 +82,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base kb = HedKB() @@ -82,14 +91,13 @@ def main(): reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [ConsistencyMetric(kb=kb)] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = HedBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the HED example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 963fa15..83c60e9 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -113,12 +113,19 @@ def main(): ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the HWF example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) loss_fn = nn.CrossEntropyLoss() @@ -140,6 +147,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base if args.ground: kb = HwfGroundKB() @@ -152,14 +161,13 @@ def main(): ) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the HWF example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index b74b82e..0616fc5 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -78,11 +78,17 @@ def main(): args = parser.parse_args() + # Build logger + print_log("Abductive Learning on the MNIST Addition example.", logger="current") + ### Working with Data + print_log("Working with Data.", logger="current") train_data = get_dataset(train=True, get_pseudo_label=True) test_data = get_dataset(train=False, get_pseudo_label=True) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + # Build necessary components for BasicNN cls = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) @@ -112,6 +118,8 @@ def main(): model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base if args.prolog: kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") @@ -126,14 +134,13 @@ def main(): ) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] ### Bridge Learning and Reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) - # Build logger - print_log("Abductive Learning on the MNIST Addition example.", logger="current") - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") diff --git a/examples/zoo/main.py b/examples/zoo/main.py index 26bdc66..4ece65f 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -30,8 +30,13 @@ def main(): "--loops", type=int, default=3, help="number of loop iterations (default : 3)" ) args = parser.parse_args() + + # Build logger + print_log("Abductive Learning on the ZOO example.", logger="current") ### Working with Data + print_log("Working with Data.", logger="current") + X, y = load_and_preprocess_dataset(dataset_id=62) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) label_data = transform_tab_data(X_label, y_label) @@ -39,12 +44,17 @@ def main(): train_data = transform_tab_data(X_unlabel, y_unlabel) ### Building the Learning Part + print_log("Building the Learning Part.", logger="current") + + # Build base model base_model = RandomForestClassifier() # Build ABLModel model = ABLModel(base_model) ### Building the Reasoning Part + print_log("Building the Reasoning Part.", logger="current") + # Build knowledge base kb = ZooKB() @@ -52,16 +62,17 @@ def main(): reasoner = Reasoner(kb, dist_func=consitency) ### Building Evaluation Metrics + print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] - # Build logger - print_log("Abductive Learning on the ZOO example.", logger="current") - log_dir = ABLLogger.get_current_instance().log_dir - weights_dir = osp.join(log_dir, "weights") - ### Bridging learning and reasoning + print_log("Bridge Learning and Reasoning.", logger="current") bridge = SimpleBridge(model, reasoner, metric_list) + # Retrieve the directory of the Log file and define the directory for saving the model weights. + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + # Performing training and testing print_log("------- Use labeled data to pretrain the model -----------", logger="current") base_model.fit(X_label, y_label)