From 70678bf968d46c9150e95b6b2b43ee7ce81b3ede Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 10 Dec 2023 01:12:56 +0800 Subject: [PATCH] [MNT] change 'ReasonerBase' to 'Reasoner' --- examples/hed/hed_bridge.py | 20 ++++++++++---------- examples/hed/hed_example.ipynb | 7 +++---- examples/hwf/hwf_example.ipynb | 6 +++--- examples/mnist_add/mnist_add_example.ipynb | 4 ++-- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/examples/hed/hed_bridge.py b/examples/hed/hed_bridge.py index 13f5946..0b08ec8 100644 --- a/examples/hed/hed_bridge.py +++ b/examples/hed/hed_bridge.py @@ -7,7 +7,7 @@ from abl.bridge import SimpleBridge from abl.dataset import RegressionDataset from abl.evaluation import BaseMetric from abl.learning import ABLModel, BasicNN -from abl.reasoning import ReasonerBase +from abl.reasoning import Reasoner from abl.structures import ListData from abl.utils import print_log from examples.hed.datasets.get_hed import get_pretrain_data @@ -19,7 +19,7 @@ class HEDBridge(SimpleBridge): def __init__( self, model: ABLModel, - reasoner: ReasonerBase, + reasoner: Reasoner, metric_list: BaseMetric, ) -> None: super().__init__(model, reasoner, metric_list) @@ -92,11 +92,11 @@ class HEDBridge(SimpleBridge): def check_training_impact(self, filtered_data_samples, data_samples): character_accuracy = self.model.valid(filtered_data_samples) revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) - print_log( - f"Revisible ratio is {revisible_ratio:.3f}, Character \ - accuracy is {character_accuracy:.3f}", - logger="current", + log_string = ( + f"Revisible ratio is {revisible_ratio:.3f}, Character " + f"accuracy is {character_accuracy:.3f}" ) + print_log(log_string, logger="current") if character_accuracy >= 0.9 and revisible_ratio >= 0.9: return True @@ -109,11 +109,11 @@ class HEDBridge(SimpleBridge): true_ratio = self.calc_consistent_ratio(val_X_true, rule) false_ratio = self.calc_consistent_ratio(val_X_false, rule) - print_log( - f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \ - is {1 - false_ratio:.3f}", - logger="current", + log_string = ( + f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio " + f"is {1 - false_ratio:.3f}" ) + print_log(log_string, logger="current") if true_ratio > 0.95 and false_ratio < 0.1: return True diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index e479b52..0e9c9d4 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -14,12 +14,11 @@ "\n", "from abl.evaluation import SemanticsMetric, SymbolMetric\n", "from abl.learning import ABLModel, BasicNN\n", - "from abl.reasoning import PrologKB, ReasonerBase\n", + "from abl.reasoning import PrologKB, Reasoner\n", "from abl.utils import ABLLogger, print_log, reform_list\n", "from examples.hed.datasets.get_hed import get_hed, split_equation\n", "from examples.hed.hed_bridge import HEDBridge\n", - "from examples.models.nn import SymbolNet\n", - "from zoopt import Dimension, Objective, Parameter, Opt" + "from examples.models.nn import SymbolNet" ] }, { @@ -68,7 +67,7 @@ " return rules\n", "\n", "\n", - "class HedReasoner(ReasonerBase):\n", + "class HedReasoner(Reasoner):\n", " def revise_at_idx(self, data_sample):\n", " revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", " candidate = self.kb.revise_at_idx(\n", diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 485087c..bae8d0d 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -11,7 +11,7 @@ "import torch.nn as nn\n", "import os.path as osp\n", "\n", - "from abl.reasoning import ReasonerBase, KBBase\n", + "from abl.reasoning import Reasoner, KBBase\n", "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", "from abl.evaluation import SymbolMetric, SemanticsMetric\n", @@ -75,7 +75,7 @@ " max_err=1e-10,\n", " use_cache=False,\n", ")\n", - "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" + "reasoner = Reasoner(kb, dist_func=\"confidence\")" ] }, { @@ -220,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index eb4e4d3..980094f 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -14,7 +14,7 @@ "from abl.bridge import SimpleBridge\n", "from abl.evaluation import SemanticsMetric, SymbolMetric\n", "from abl.learning import ABLModel, BasicNN\n", - "from abl.reasoning import KBBase, ReasonerBase\n", + "from abl.reasoning import KBBase, Reasoner\n", "from abl.utils import ABLLogger, print_log\n", "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", "from examples.models.nn import LeNet5" @@ -109,7 +109,7 @@ "\n", "\n", "kb = AddKB(pseudo_label_list=list(range(10)))\n", - "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" + "reasoner = Reasoner(kb, dist_func=\"confidence\")" ] }, {