| @@ -26,6 +26,7 @@ class BaseMetric(metaclass=ABCMeta): | |||
| self, | |||
| prefix: Optional[str] = None, | |||
| ) -> None: | |||
| self.default_prefix = "" | |||
| self.results: List[Any] = [] | |||
| self.prefix = prefix or self.default_prefix | |||
| @@ -204,7 +204,7 @@ class Reasoner: | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
| ) | |||
| parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
| parameter = Parameter(budget=200, intermediate_result=False, autoset=True) | |||
| solution = Opt.min(objective, parameter) | |||
| return solution | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| from collections import defaultdict | |||
| from typing import Any, List, Optional, Tuple, Union | |||
| import torch | |||
| @@ -41,7 +42,7 @@ class HedBridge(SimpleBridge): | |||
| cls_autoencoder, | |||
| loss_fn, | |||
| optimizer, | |||
| device, | |||
| device=device, | |||
| save_interval=1, | |||
| save_dir=weights_dir, | |||
| num_epochs=10, | |||
| @@ -115,7 +116,7 @@ class HedBridge(SimpleBridge): | |||
| ) | |||
| print_log(log_string, logger="current") | |||
| if true_ratio > 0.95 and false_ratio < 0.1: | |||
| if true_ratio > 0.9 and false_ratio < 0.05: | |||
| return True | |||
| return False | |||
| @@ -215,7 +216,7 @@ class HedBridge(SimpleBridge): | |||
| self.abduce_pseudo_label(sub_data_examples) | |||
| filtered_sub_data_examples = self.filter_empty(sub_data_examples) | |||
| self.pseudo_label_to_idx(filtered_sub_data_examples) | |||
| loss = self.model.train(filtered_sub_data_examples) | |||
| self.model.train(filtered_sub_data_examples) | |||
| if self.check_training_impact(filtered_sub_data_examples, sub_data_examples): | |||
| condition_num += 1 | |||
| @@ -231,6 +232,7 @@ class HedBridge(SimpleBridge): | |||
| seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
| if seems_good: | |||
| self.reasoner.kb.learned_rules.update({equation_len: rules}) | |||
| self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") | |||
| break | |||
| else: | |||
| @@ -244,3 +246,19 @@ class HedBridge(SimpleBridge): | |||
| self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") | |||
| condition_num = 0 | |||
| print_log("Reload Model and retrain", logger="current") | |||
| def test( | |||
| self, | |||
| test_data: Union[ | |||
| ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] | |||
| ], | |||
| min_len=5, | |||
| max_len=8, | |||
| ) -> None: | |||
| for equation_len in range(min_len, max_len): | |||
| test_data_examples = self.data_preprocess(test_data[1], equation_len) | |||
| print_log(f"Test on true equations with length {equation_len}", logger="current") | |||
| self._valid(test_data_examples) | |||
| test_data_examples = self.data_preprocess(test_data[0], equation_len) | |||
| print_log(f"Test on false equations with length {equation_len}", logger="current") | |||
| self._valid(test_data_examples) | |||
| @@ -0,0 +1,28 @@ | |||
| from typing import Optional | |||
| from abl.reasoning import KBBase | |||
| from abl.data.structures import ListData | |||
| from abl.data.evaluation.base_metric import BaseMetric | |||
| class ConsistencyMetric(BaseMetric): | |||
| def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||
| super().__init__(prefix) | |||
| self.kb = kb | |||
| def process(self, data_examples: ListData) -> None: | |||
| pred_pseudo_label = data_examples.pred_pseudo_label | |||
| learned_rules = self.kb.learned_rules | |||
| consistent_num = sum( | |||
| [ | |||
| self.kb.consist_rule(instance, learned_rules[len(instance)]) | |||
| for instance in pred_pseudo_label | |||
| ] | |||
| ) | |||
| self.results.append((consistent_num, len(pred_pseudo_label))) | |||
| def compute_metrics(self) -> dict: | |||
| results = self.results | |||
| metrics = dict() | |||
| metrics["consistency"] = sum(t[0] for t in results) / sum(t[1] for t in results) | |||
| return metrics | |||
| @@ -13,7 +13,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 1, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -26,7 +26,8 @@ | |||
| "from examples.models.nn import SymbolNet\n", | |||
| "from abl.learning import ABLModel, BasicNN\n", | |||
| "from examples.hed.reasoning import HedKB, HedReasoner\n", | |||
| "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
| "from abl.data.evaluation import SymbolAccuracy\n", | |||
| "from examples.hed.consistency_metric import ConsistencyMetric\n", | |||
| "from abl.utils import ABLLogger, print_log\n", | |||
| "from examples.hed.bridge import HedBridge" | |||
| ] | |||
| @@ -47,7 +48,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 2, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -65,7 +66,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 3, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| @@ -119,7 +120,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| @@ -240,7 +241,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -254,7 +255,7 @@ | |||
| " cls,\n", | |||
| " loss_fn,\n", | |||
| " optimizer,\n", | |||
| " device,\n", | |||
| " device=device,\n", | |||
| " batch_size=32,\n", | |||
| " num_epochs=1,\n", | |||
| " stop_loss=None,\n", | |||
| @@ -270,7 +271,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -298,7 +299,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 7, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -316,7 +317,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 8, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -340,11 +341,11 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 9, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "metric_list = [SymbolAccuracy(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" | |||
| "metric_list = [ConsistencyMetric(kb=kb)]" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -359,7 +360,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 10, | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -9,6 +9,7 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| class HedKB(PrologKB): | |||
| def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| self.learned_rules = {} | |||
| def consist_rule(self, exs, rules): | |||
| rules = str(rules).replace("'", "") | |||