Browse Source

[MNT] resolve all comments in hwf_example.ipynb

pull/3/head
Gao Enhao 2 years ago
parent
commit
2cf678910b
2 changed files with 29 additions and 20 deletions
  1. +9
    -5
      abl/utils/logger.py
  2. +20
    -15
      examples/hwf/hwf_example.ipynb

+ 9
- 5
abl/utils/logger.py View File

@@ -198,11 +198,11 @@ class ABLLogger(Logger, ManagerMixin):
import time
local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())

save_dir = os.path.join("results", local_time)
self.save_dir = save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
log_file = osp.join(save_dir, local_time + ".log")
_log_dir = os.path.join("results", local_time)
self._log_dir = _log_dir
if not os.path.exists(_log_dir):
os.makedirs(_log_dir)
log_file = osp.join(_log_dir, local_time + ".log")

file_handler = logging.FileHandler(log_file, file_mode)
file_handler.setFormatter(
@@ -216,6 +216,10 @@ class ABLLogger(Logger, ManagerMixin):
@property
def log_file(self):
return self._log_file
@property
def log_dir(self):
return self._log_dir

@classmethod
def get_current_instance(cls) -> "ABLLogger":


+ 20
- 15
examples/hwf/hwf_example.ipynb View File

@@ -6,15 +6,16 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torch\n",
"import os.path as osp\n",
"\n",
"from abl.reasoning import ReasonerBase, KBBase\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, ABLMetric\n",
"from abl.utils import ABLLogger\n",
"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from examples.models.nn import SymbolNet\n",
"from datasets.get_hwf import get_hwf"
@@ -26,8 +27,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"logger = ABLLogger.get_instance(\"abl\")"
"# Initialize logger and print basic information\n",
"print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
@@ -108,12 +113,12 @@
"# Initialize BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
" device,\n",
" model=cls,\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" device=device,\n",
" save_interval=1,\n",
" save_dir=logger.save_dir,\n",
" save_dir=weights_dir,\n",
" batch_size=128,\n",
" num_epochs=3,\n",
")"
@@ -146,7 +151,7 @@
"outputs": [],
"source": [
"# Add metric\n",
"metric = [SymbolMetric(prefix=\"hwf\"), ABLMetric(prefix=\"hwf\")]"
"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]"
]
},
{
@@ -164,8 +169,8 @@
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_hwf(train=True, get_pseudo_label=True)\n",
"test_data = get_hwf(train=False, get_pseudo_label=True)"
"train_data = get_hwf(train=True, get_gt_pseudo_label=True)\n",
"test_data = get_hwf(train=False, get_gt_pseudo_label=True)"
]
},
{
@@ -182,7 +187,7 @@
"metadata": {},
"outputs": [],
"source": [
"bridge = SimpleBridge(model, abducer, metric)"
"bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)"
]
},
{
@@ -220,7 +225,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4,
"vscode": {


Loading…
Cancel
Save