Browse Source

[MNT] change 'ReasonerBase' to 'Reasoner'

pull/1/head
Gao Enhao 2 years ago
parent
commit
70678bf968
4 changed files with 18 additions and 19 deletions
  1. +10
    -10
      examples/hed/hed_bridge.py
  2. +3
    -4
      examples/hed/hed_example.ipynb
  3. +3
    -3
      examples/hwf/hwf_example.ipynb
  4. +2
    -2
      examples/mnist_add/mnist_add_example.ipynb

+ 10
- 10
examples/hed/hed_bridge.py View File

@@ -7,7 +7,7 @@ from abl.bridge import SimpleBridge
from abl.dataset import RegressionDataset
from abl.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.reasoning import Reasoner
from abl.structures import ListData
from abl.utils import print_log
from examples.hed.datasets.get_hed import get_pretrain_data
@@ -19,7 +19,7 @@ class HEDBridge(SimpleBridge):
def __init__(
self,
model: ABLModel,
reasoner: ReasonerBase,
reasoner: Reasoner,
metric_list: BaseMetric,
) -> None:
super().__init__(model, reasoner, metric_list)
@@ -92,11 +92,11 @@ class HEDBridge(SimpleBridge):
def check_training_impact(self, filtered_data_samples, data_samples):
character_accuracy = self.model.valid(filtered_data_samples)
revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X)
print_log(
f"Revisible ratio is {revisible_ratio:.3f}, Character \
accuracy is {character_accuracy:.3f}",
logger="current",
log_string = (
f"Revisible ratio is {revisible_ratio:.3f}, Character "
f"accuracy is {character_accuracy:.3f}"
)
print_log(log_string, logger="current")

if character_accuracy >= 0.9 and revisible_ratio >= 0.9:
return True
@@ -109,11 +109,11 @@ class HEDBridge(SimpleBridge):
true_ratio = self.calc_consistent_ratio(val_X_true, rule)
false_ratio = self.calc_consistent_ratio(val_X_false, rule)

print_log(
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio \
is {1 - false_ratio:.3f}",
logger="current",
log_string = (
f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio "
f"is {1 - false_ratio:.3f}"
)
print_log(log_string, logger="current")

if true_ratio > 0.95 and false_ratio < 0.1:
return True


+ 3
- 4
examples/hed/hed_example.ipynb View File

@@ -14,12 +14,11 @@
"\n",
"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import PrologKB, ReasonerBase\n",
"from abl.reasoning import PrologKB, Reasoner\n",
"from abl.utils import ABLLogger, print_log, reform_list\n",
"from examples.hed.datasets.get_hed import get_hed, split_equation\n",
"from examples.hed.hed_bridge import HEDBridge\n",
"from examples.models.nn import SymbolNet\n",
"from zoopt import Dimension, Objective, Parameter, Opt"
"from examples.models.nn import SymbolNet"
]
},
{
@@ -68,7 +67,7 @@
" return rules\n",
"\n",
"\n",
"class HedReasoner(ReasonerBase):\n",
"class HedReasoner(Reasoner):\n",
" def revise_at_idx(self, data_sample):\n",
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n",
" candidate = self.kb.revise_at_idx(\n",


+ 3
- 3
examples/hwf/hwf_example.ipynb View File

@@ -11,7 +11,7 @@
"import torch.nn as nn\n",
"import os.path as osp\n",
"\n",
"from abl.reasoning import ReasonerBase, KBBase\n",
"from abl.reasoning import Reasoner, KBBase\n",
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
@@ -75,7 +75,7 @@
" max_err=1e-10,\n",
" use_cache=False,\n",
")\n",
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")"
"reasoner = Reasoner(kb, dist_func=\"confidence\")"
]
},
{
@@ -220,7 +220,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.8.18"
},
"orig_nbformat": 4,
"vscode": {


+ 2
- 2
examples/mnist_add/mnist_add_example.ipynb View File

@@ -14,7 +14,7 @@
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import KBBase, ReasonerBase\n",
"from abl.reasoning import KBBase, Reasoner\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add\n",
"from examples.models.nn import LeNet5"
@@ -109,7 +109,7 @@
"\n",
"\n",
"kb = AddKB(pseudo_label_list=list(range(10)))\n",
"reasoner = ReasonerBase(kb, dist_func=\"confidence\")"
"reasoner = Reasoner(kb, dist_func=\"confidence\")"
]
},
{


Loading…
Cancel
Save