From 2cf678910b3d5e02a2fff3afb8955925f9687cc0 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 25 Oct 2023 22:05:09 +0800 Subject: [PATCH] [MNT] resolve all comments in hwf_example.ipynb --- abl/utils/logger.py | 14 +++++++++----- examples/hwf/hwf_example.ipynb | 35 +++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/abl/utils/logger.py b/abl/utils/logger.py index 348106e..68b5ffc 100644 --- a/abl/utils/logger.py +++ b/abl/utils/logger.py @@ -198,11 +198,11 @@ class ABLLogger(Logger, ManagerMixin): import time local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) - save_dir = os.path.join("results", local_time) - self.save_dir = save_dir - if not os.path.exists(save_dir): - os.makedirs(save_dir) - log_file = osp.join(save_dir, local_time + ".log") + _log_dir = os.path.join("results", local_time) + self._log_dir = _log_dir + if not os.path.exists(_log_dir): + os.makedirs(_log_dir) + log_file = osp.join(_log_dir, local_time + ".log") file_handler = logging.FileHandler(log_file, file_mode) file_handler.setFormatter( @@ -216,6 +216,10 @@ class ABLLogger(Logger, ManagerMixin): @property def log_file(self): return self._log_file + + @property + def log_dir(self): + return self._log_dir @classmethod def get_current_instance(cls) -> "ABLLogger": diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index b1ba550..932a25e 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -6,15 +6,16 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import numpy as np\n", "import torch.nn as nn\n", - "import torch\n", + "import os.path as osp\n", "\n", "from abl.reasoning import ReasonerBase, KBBase\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", - "from abl.evaluation import SymbolMetric, ABLMetric\n", - "from abl.utils import ABLLogger\n", + "from abl.evaluation import SymbolMetric, SemanticsMetric\n", + "from abl.utils import ABLLogger, print_log\n", "\n", "from examples.models.nn import SymbolNet\n", "from datasets.get_hwf import get_hwf" @@ -26,8 +27,12 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize logger\n", - "logger = ABLLogger.get_instance(\"abl\")" + "# Initialize logger and print basic information\n", + "print_log(\"Abductive Learning on the HWF example.\", logger=\"current\")\n", + "\n", + "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", + "log_dir = ABLLogger.get_current_instance().log_dir\n", + "weights_dir = osp.join(log_dir, \"weights\")" ] }, { @@ -108,12 +113,12 @@ "# Initialize BasicNN\n", "# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n", "base_model = BasicNN(\n", - " cls,\n", - " criterion,\n", - " optimizer,\n", - " device,\n", + " model=cls,\n", + " criterion=criterion,\n", + " optimizer=optimizer,\n", + " device=device,\n", " save_interval=1,\n", - " save_dir=logger.save_dir,\n", + " save_dir=weights_dir,\n", " batch_size=128,\n", " num_epochs=3,\n", ")" @@ -146,7 +151,7 @@ "outputs": [], "source": [ "# Add metric\n", - "metric = [SymbolMetric(prefix=\"hwf\"), ABLMetric(prefix=\"hwf\")]" + "metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(prefix=\"hwf\")]" ] }, { @@ -164,8 +169,8 @@ "outputs": [], "source": [ "# Get training and testing data\n", - "train_data = get_hwf(train=True, get_pseudo_label=True)\n", - "test_data = get_hwf(train=False, get_pseudo_label=True)" + "train_data = get_hwf(train=True, get_gt_pseudo_label=True)\n", + "test_data = get_hwf(train=False, get_gt_pseudo_label=True)" ] }, { @@ -182,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "bridge = SimpleBridge(model, abducer, metric)" + "bridge = SimpleBridge(model=model, abducer=abducer, metric_list=metric_list)" ] }, { @@ -220,7 +225,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4, "vscode": {