|
|
|
@@ -6,15 +6,16 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import torch\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"import torch\n", |
|
|
|
"import os.path as osp\n", |
|
|
|
"\n", |
|
|
|
"from abl.reasoning import ReasonerBase, KBBase\n", |
|
|
|
"from abl.learning import BasicNN, ABLModel\n", |
|
|
|
"from abl.bridge import SimpleBridge\n", |
|
|
|
"from abl.evaluation import SymbolMetric, ABLMetric\n", |
|
|
|
"from abl.utils import ABLLogger\n", |
|
|
|
"from abl.evaluation import SymbolMetric, SemanticsMetric\n", |
|
|
|
"from abl.utils import ABLLogger, print_log\n", |
|
|
|
"\n", |
|
|
|
"from examples.models.nn import SymbolNet\n", |
|
|
|
"from datasets.get_hwf import get_hwf" |
|
|
|
@@ -26,8 +27,12 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Initialize logger\n", |
|
|
|
"logger = ABLLogger.get_instance(\"abl\")" |
|
|
|
"# Initialize logger and print basic information\n", |
|
|
|
"print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", |
|
|
|
"\n", |
|
|
|
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", |
|
|
|
"log_dir = ABLLogger.get_current_instance().log_dir\n", |
|
|
|
"weights_dir = osp.join(log_dir, \"weights\")" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -108,12 +113,12 @@ |
|
|
|
"# Initialize BasicNN\n", |
|
|
|
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", |
|
|
|
"base_model = BasicNN(\n", |
|
|
|
" cls,\n", |
|
|
|
" criterion,\n", |
|
|
|
" optimizer,\n", |
|
|
|
" device,\n", |
|
|
|
" model=cls,\n", |
|
|
|
" criterion=criterion,\n", |
|
|
|
" optimizer=optimizer,\n", |
|
|
|
" device=device,\n", |
|
|
|
" save_interval=1,\n", |
|
|
|
" save_dir=logger.save_dir,\n", |
|
|
|
" save_dir=weights_dir,\n", |
|
|
|
" batch_size=128,\n", |
|
|
|
" num_epochs=3,\n", |
|
|
|
")" |
|
|
|
@@ -146,7 +151,7 @@ |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Add metric\n", |
|
|
|
"metric = [SymbolMetric(prefix=\"hwf\"), ABLMetric(prefix=\"hwf\")]" |
|
|
|
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -164,8 +169,8 @@ |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# Get training and testing data\n", |
|
|
|
"train_data = get_hwf(train=True, get_pseudo_label=True)\n", |
|
|
|
"test_data = get_hwf(train=False, get_pseudo_label=True)" |
|
|
|
"train_data = get_hwf(train=True, get_gt_pseudo_label=True)\n", |
|
|
|
"test_data = get_hwf(train=False, get_gt_pseudo_label=True)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -182,7 +187,7 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"bridge = SimpleBridge(model, abducer, metric)" |
|
|
|
"bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
@@ -220,7 +225,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.8.13" |
|
|
|
"version": "3.8.16" |
|
|
|
}, |
|
|
|
"orig_nbformat": 4, |
|
|
|
"vscode": { |
|
|
|
|