| @@ -46,13 +46,20 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| total_train_data = get_dataset(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_dataset(train=False) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=4) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -75,6 +82,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = HedKB() | |||
| @@ -82,14 +91,13 @@ def main(): | |||
| reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [ConsistencyMetric(kb=kb)] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = HedBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -113,12 +113,19 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) | |||
| loss_fn = nn.CrossEntropyLoss() | |||
| @@ -140,6 +147,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.ground: | |||
| kb = HwfGroundKB() | |||
| @@ -152,14 +161,13 @@ def main(): | |||
| ) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the HWF example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -78,11 +78,17 @@ def main(): | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| train_data = get_dataset(train=True, get_pseudo_label=True) | |||
| test_data = get_dataset(train=False, get_pseudo_label=True) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build necessary components for BasicNN | |||
| cls = LeNet5(num_classes=10) | |||
| loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||
| @@ -112,6 +118,8 @@ def main(): | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| if args.prolog: | |||
| kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") | |||
| @@ -126,14 +134,13 @@ def main(): | |||
| ) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] | |||
| ### Bridge Learning and Reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Build logger | |||
| print_log("Abductive Learning on the MNIST Addition example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| @@ -30,8 +30,13 @@ def main(): | |||
| "--loops", type=int, default=3, help="number of loop iterations (default : 3)" | |||
| ) | |||
| args = parser.parse_args() | |||
| # Build logger | |||
| print_log("Abductive Learning on the ZOO example.", logger="current") | |||
| ### Working with Data | |||
| print_log("Working with Data.", logger="current") | |||
| X, y = load_and_preprocess_dataset(dataset_id=62) | |||
| X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
| label_data = transform_tab_data(X_label, y_label) | |||
| @@ -39,12 +44,17 @@ def main(): | |||
| train_data = transform_tab_data(X_unlabel, y_unlabel) | |||
| ### Building the Learning Part | |||
| print_log("Building the Learning Part.", logger="current") | |||
| # Build base model | |||
| base_model = RandomForestClassifier() | |||
| # Build ABLModel | |||
| model = ABLModel(base_model) | |||
| ### Building the Reasoning Part | |||
| print_log("Building the Reasoning Part.", logger="current") | |||
| # Build knowledge base | |||
| kb = ZooKB() | |||
| @@ -52,16 +62,17 @@ def main(): | |||
| reasoner = Reasoner(kb, dist_func=consitency) | |||
| ### Building Evaluation Metrics | |||
| print_log("Building Evaluation Metrics.", logger="current") | |||
| metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
| # Build logger | |||
| print_log("Abductive Learning on the ZOO example.", logger="current") | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| ### Bridging learning and reasoning | |||
| print_log("Bridge Learning and Reasoning.", logger="current") | |||
| bridge = SimpleBridge(model, reasoner, metric_list) | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| # Performing training and testing | |||
| print_log("------- Use labeled data to pretrain the model -----------", logger="current") | |||
| base_model.fit(X_label, y_label) | |||