| @@ -7,7 +7,7 @@ from abl.bridge import SimpleBridge | |||
| from abl.dataset import RegressionDataset | |||
| from abl.evaluation import BaseMetric | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import ReasonerBase | |||
| from abl.reasoning import Reasoner | |||
| from abl.structures import ListData | |||
| from abl.utils import print_log | |||
| from examples.hed.datasets.get_hed import get_pretrain_data | |||
| @@ -19,7 +19,7 @@ class HEDBridge(SimpleBridge): | |||
| def __init__( | |||
| self, | |||
| model: ABLModel, | |||
| reasoner: ReasonerBase, | |||
| reasoner: Reasoner, | |||
| metric_list: BaseMetric, | |||
| ) -> None: | |||
| super().__init__(model, reasoner, metric_list) | |||
| @@ -92,11 +92,11 @@ class HEDBridge(SimpleBridge): | |||
| def check_training_impact(self, filtered_data_samples, data_samples): | |||
| character_accuracy = self.model.valid(filtered_data_samples) | |||
| revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | |||
| print_log( | |||
| f"Revisible ratio is {revisible_ratio:.3f}, Character \ | |||
| accuracy is {character_accuracy:.3f}", | |||
| logger="current", | |||
| log_string = ( | |||
| f"Revisible ratio is {revisible_ratio:.3f}, Character " | |||
| f"accuracy is {character_accuracy:.3f}" | |||
| ) | |||
| print_log(log_string, logger="current") | |||
| if character_accuracy >= 0.9 and revisible_ratio >= 0.9: | |||
| return True | |||
| @@ -109,11 +109,11 @@ class HEDBridge(SimpleBridge): | |||
| true_ratio = self.calc_consistent_ratio(val_X_true, rule) | |||
| false_ratio = self.calc_consistent_ratio(val_X_false, rule) | |||
| print_log( | |||
| f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \ | |||
| is {1 - false_ratio:.3f}", | |||
| logger="current", | |||
| log_string = ( | |||
| f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio " | |||
| f"is {1 - false_ratio:.3f}" | |||
| ) | |||
| print_log(log_string, logger="current") | |||
| if true_ratio > 0.95 and false_ratio < 0.1: | |||
| return True | |||
| @@ -14,12 +14,11 @@ | |||
| "\n", | |||
| "from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||
| "from abl.learning import ABLModel, BasicNN\n", | |||
| "from abl.reasoning import PrologKB, ReasonerBase\n", | |||
| "from abl.reasoning import PrologKB, Reasoner\n", | |||
| "from abl.utils import ABLLogger, print_log, reform_list\n", | |||
| "from examples.hed.datasets.get_hed import get_hed, split_equation\n", | |||
| "from examples.hed.hed_bridge import HEDBridge\n", | |||
| "from examples.models.nn import SymbolNet\n", | |||
| "from zoopt import Dimension, Objective, Parameter, Opt" | |||
| "from examples.models.nn import SymbolNet" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -68,7 +67,7 @@ | |||
| " return rules\n", | |||
| "\n", | |||
| "\n", | |||
| "class HedReasoner(ReasonerBase):\n", | |||
| "class HedReasoner(Reasoner):\n", | |||
| " def revise_at_idx(self, data_sample):\n", | |||
| " revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", | |||
| " candidate = self.kb.revise_at_idx(\n", | |||
| @@ -11,7 +11,7 @@ | |||
| "import torch.nn as nn\n", | |||
| "import os.path as osp\n", | |||
| "\n", | |||
| "from abl.reasoning import ReasonerBase, KBBase\n", | |||
| "from abl.reasoning import Reasoner, KBBase\n", | |||
| "from abl.learning import BasicNN, ABLModel\n", | |||
| "from abl.bridge import SimpleBridge\n", | |||
| "from abl.evaluation import SymbolMetric, SemanticsMetric\n", | |||
| @@ -75,7 +75,7 @@ | |||
| " max_err=1e-10,\n", | |||
| " use_cache=False,\n", | |||
| ")\n", | |||
| "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" | |||
| "reasoner = Reasoner(kb, dist_func=\"confidence\")" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -220,7 +220,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.18" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| @@ -14,7 +14,7 @@ | |||
| "from abl.bridge import SimpleBridge\n", | |||
| "from abl.evaluation import SemanticsMetric, SymbolMetric\n", | |||
| "from abl.learning import ABLModel, BasicNN\n", | |||
| "from abl.reasoning import KBBase, ReasonerBase\n", | |||
| "from abl.reasoning import KBBase, Reasoner\n", | |||
| "from abl.utils import ABLLogger, print_log\n", | |||
| "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n", | |||
| "from examples.models.nn import LeNet5" | |||
| @@ -109,7 +109,7 @@ | |||
| "\n", | |||
| "\n", | |||
| "kb = AddKB(pseudo_label_list=list(range(10)))\n", | |||
| "reasoner = ReasonerBase(kb, dist_func=\"confidence\")" | |||
| "reasoner = Reasoner(kb, dist_func=\"confidence\")" | |||
| ] | |||
| }, | |||
| { | |||