Browse Source

[ENH] add search engine

ab_data
Gao Enhao 2 years ago
parent
commit
2951e5fe5a
8 changed files with 535 additions and 523 deletions
  1. +1
    -0
      abl/reasoning/__init__.py
  2. +43
    -503
      abl/reasoning/reasoner.py
  3. +2
    -20
      abl/reasoning/search_based_kb.py
  4. +3
    -0
      abl/reasoning/search_engine/__init__.py
  5. +13
    -0
      abl/reasoning/search_engine/base_search_engine.py
  6. +28
    -0
      abl/reasoning/search_engine/bfs.py
  7. +42
    -0
      abl/reasoning/search_engine/zoopt.py
  8. +403
    -0
      tests/test_reasoning.py

+ 1
- 0
abl/reasoning/__init__.py View File

@@ -3,3 +3,4 @@ from .ground_kb import GroundKB
from .prolog_based_kb import PrologBasedKB
from .reasoner import ReasonerBase
from .search_based_kb import SearchBasedKB
from .search_engine import BaseSearchEngine, BFS, Zoopt

+ 43
- 503
abl/reasoning/reasoner.py View File

@@ -1,11 +1,11 @@
from typing import Any, List, Mapping, Tuple, Union
from typing import Any, List, Mapping

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ..structures import ListData
from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx
from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist
from .base_kb import BaseKB
from .search_engine import BaseSearchEngine, BFS


class ReasonerBase:
@@ -14,7 +14,7 @@ class ReasonerBase:
kb: BaseKB,
dist_func: str = "confidence",
mapping: Mapping = None,
use_zoopt: bool = False,
search_engine: BaseSearchEngine = None,
):
"""
Base class for all reasoner in the ABL system.
@@ -36,12 +36,14 @@ class ReasonerBase:
If the specified distance function is neither "hamming" nor "confidence".
"""

if not isinstance(kb, BaseKB):
raise ValueError("The kb should be of type BaseKB.")
self.kb = kb

if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")

self.kb = kb
self.dist_func = dist_func
self.use_zoopt = use_zoopt

if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
@@ -56,10 +58,17 @@ class ReasonerBase:
raise ValueError("All values in the mapping must be in the pseudo_label_list")

self.mapping = mapping

self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]):
if search_engine is None:
self.search_engine = BFS()
else:
if not isinstance(search_engine, BaseSearchEngine):
raise ValueError("The search_engine should be of type BaseSearchEngine.")
else:
self.search_engine = search_engine

def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Get the list of costs between each pseudo label and candidate.

@@ -84,7 +93,7 @@ class ReasonerBase:
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample["pred_prob"][0], candidates)

def _get_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
"""
Get one candidate. If multiple candidates exist, return the one with minimum cost.

@@ -108,91 +117,10 @@ class ReasonerBase:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(data_sample, candidates)
cost_array = self._get_dist_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate

def zoopt_revision_score(self, data_sample: ListData, solution: Solution):
"""
Get the revision score for a single solution.

Parameters
----------
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
List of probabilities for predicted results.
y : any
Ground truth for the predicted results.
solution : array-like
Solution to evaluate.

Returns
-------
float
The revision score for the given solution.
"""
revision_idx = np.where(solution.get_x() != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates))
else:
return data_sample["symbol_num"]

def _constrain_revision_num(self, solution: Solution, max_revision_num: int):
x = solution.get_x()
return max_revision_num - x.sum()

def zoopt_get_solution(self, data_sample: ListData, max_revision_num: int):
"""Get the optimal solution using the Zoopt library.

Parameters
----------
pred_pseudo_label : list
List of predicted pseudo labels.
pred_prob : list
List of probabilities for predicted results.
y : any
Ground truth for the predicted results.
max_revision_num : int
Maximum number of revisions to use.

Returns
-------
array-like
The optimal solution, i.e., where to revise predict pseudo label.
"""
symbol_num = data_sample["symbol_num"]
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
lambda solution: self.zoopt_revision_score(data_sample, solution),
dim=dimension,
constraint=lambda solution: self._constrain_revision_num(solution, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x()
return solution

def revise_at_idx(self, data_sample: ListData, revision_idx: Union[List, Tuple, np.ndarray]):
"""
Revise the pseudo label according to the given indices.

Parameters
----------
pred_pseudo_label : list
List of predicted pseudo labels.
y : any
Ground truth for the predicted results.
revision_idx : array-like
Indices of the revisions to retrieve.

Returns
-------
list
The revisions according to the given indices.
"""
return self.kb.revise_at_idx(data_sample, revision_idx)

def abduce(
self,
data_sample: ListData,
@@ -223,19 +151,35 @@ class ReasonerBase:
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = calculate_revision_num(max_revision, symbol_num)

data_sample.set_metainfo(dict(symbol_num=symbol_num))

if self.use_zoopt:
solution = self.zoopt_get_solution(data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
else:
if hasattr(self.kb, "abduce_candidates"):
candidates = self.kb.abduce_candidates(
data_sample, max_revision_num, require_more_revision
)
elif hasattr(self.kb, "revise_at_idx"):
candidates = []
gen = self.search_engine.generator(
data_sample,
max_revision_num=max_revision_num,
require_more_revision=require_more_revision,
)
send_signal = True
for revision_idx in gen:
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
if len(candidates) > 0 and send_signal:
try:
revision_idx = gen.send("success")
candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
send_signal = False
except StopIteration:
break
else:
raise NotImplementedError(
"The kb should either implement abduce_candidates or revise_at_idx."
)

candidate = self._get_one_candidate(data_sample, candidates)
candidate = self.select_one_candidate(data_sample, candidates)
return candidate

def batch_abduce(
@@ -285,407 +229,3 @@ class ReasonerBase:
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results


if __name__ == "__main__":
from abl.reasoning.base_kb import BaseKB, GroundKB, PrologBasedKB

prob1 = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

prob2 = [
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

class add_KB(BaseKB):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)

class add_GroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)

def logic_forward(self, nums):
return sum(nums)

def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
print(res)
print()

print("add_KB with GKB:")
kb = add_GroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl"
)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl using zoopt:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("add_KB with multiple inputs at once:")
multiple_prob = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
]

kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

class HWF_KB(BaseKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
max_err=1e-3,
):
super().__init__(pseudo_label_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

class HWF_GroundKB(GroundKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-3,
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()

def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3])

print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9])

class HED_prolog_KB(PrologBasedKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules

class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_at_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0
import math

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs1 = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HED_kb consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print("HED_Reasoner abduce")
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs))
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
print()

print("HED_Reasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)

+ 2
- 20
abl/reasoning/search_based_kb.py View File

@@ -4,28 +4,10 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import numpy

from abl.structures import ListData

from ..structures import ListData
from ..utils import Cache
from .base_kb import BaseKB


def incremental_search_strategy(
    data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
):
    """
    Yield tuples of revision indices, smallest revision count first.

    Every combination of ``k`` positions out of ``data_sample["symbol_num"]``
    is yielded for ``k = 0, 1, ...`` up to ``max_revision_num`` (capped at the
    number of symbols). The consumer may ``send("success")`` into the
    generator; from then on the search stops after revision count
    ``k + require_more_revision``, where ``k`` is the count being enumerated
    when the signal arrived.

    Parameters
    ----------
    data_sample : ListData
        Sample providing the ``"symbol_num"`` entry.
    max_revision_num : int
        Upper bound on the number of positions revised at once.
    require_more_revision : int, optional
        Extra revision counts to explore after a success signal, by default 0.

    Yields
    ------
    tuple
        Indices of the positions to revise.
    """
    total_symbols = data_sample["symbol_num"]
    upper_bound = min(max_revision_num, total_symbols)
    # Last revision count still worth exploring; tightened on "success".
    stop_after = upper_bound

    for size in range(upper_bound + 1):
        if size > stop_after:
            return

        for picked in combinations(range(total_symbols), size):
            signal = yield picked
            if signal == "success":
                stop_after = min(total_symbols, size + require_more_revision)
from .search_engine import incremental_search_strategy


class SearchBasedKB(BaseKB, ABC):
@@ -35,7 +17,7 @@ class SearchBasedKB(BaseKB, ABC):
search_strategy: Callable[[ListData, int, int], Generator] = incremental_search_strategy,
use_cache: bool = True,
cache_file: Optional[str] = None,
cache_size: int = 4096
cache_size: int = 4096,
) -> None:
super().__init__(pseudo_label_list)
self.search_strategy = search_strategy


+ 3
- 0
abl/reasoning/search_engine/__init__.py View File

@@ -0,0 +1,3 @@
from .base_search_engine import BaseSearchEngine
from .bfs import BFS
from .zoopt import Zoopt

+ 13
- 0
abl/reasoning/search_engine/base_search_engine.py View File

@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import numpy

from ...structures import ListData


class BaseSearchEngine(ABC):
    """
    Abstract interface for engines that propose which positions to revise.

    A concrete engine implements :meth:`generator` as a Python generator.
    The consumer iterates over it, treats every yielded value as a
    collection of revision indices, and may ``send("success")`` into the
    generator once a yielded value produced at least one candidate.
    """

    @abstractmethod
    def generator(
        self,
        data_sample: ListData,
        max_revision_num: int,
        require_more_revision: int = 0,
    ) -> Union[List, Tuple, numpy.ndarray]:
        """
        Yield revision indices for ``data_sample``.

        Parameters
        ----------
        data_sample : ListData
            The sample whose pseudo labels are to be revised.
        max_revision_num : int
            Upper bound on the number of positions revised at once.
        require_more_revision : int, optional
            Extra revision counts to explore after a success signal,
            by default 0.

        Yields
        ------
        Union[List, Tuple, numpy.ndarray]
            Indices of the positions to revise. (The return annotation
            describes the yielded items, not the generator object.)
        """
        pass

+ 28
- 0
abl/reasoning/search_engine/bfs.py View File

@@ -0,0 +1,28 @@
from itertools import combinations
from typing import List, Tuple, Union

import numpy

from ...structures import ListData
from .base_search_engine import BaseSearchEngine


class BFS(BaseSearchEngine):
    """
    Breadth-first search engine over revision positions.

    Revision index tuples are enumerated by increasing size: first the empty
    revision, then every single position, then every pair, and so on.
    """

    def __init__(self) -> None:
        pass

    def generator(
        self,
        data_sample: ListData,
        max_revision_num: int,
        require_more_revision: int = 0,
    ) -> Union[List, Tuple, numpy.ndarray]:
        """
        Yield tuples of revision indices, smallest revision count first.

        The consumer may ``send("success")`` into the generator; from then on
        the search stops after revision count ``k + require_more_revision``,
        where ``k`` is the count being enumerated when the signal arrived.

        Parameters
        ----------
        data_sample : ListData
            Sample providing the ``"symbol_num"`` entry.
        max_revision_num : int
            Upper bound on the number of positions revised at once; capped
            at the number of symbols.
        require_more_revision : int, optional
            Extra revision counts to explore after a success signal,
            by default 0.

        Yields
        ------
        tuple
            Indices of the positions to revise.
        """
        # NOTE: ``self`` was missing from the original signature, so a call
        # on an instance bound the engine itself to ``data_sample``.
        symbol_num = data_sample["symbol_num"]
        max_revision_num = min(max_revision_num, symbol_num)
        # Last revision count still worth exploring; tightened on "success".
        real_end = max_revision_num
        for revision_num in range(max_revision_num + 1):
            if revision_num > real_end:
                break

            for revision_idx in combinations(range(symbol_num), revision_num):
                received = yield revision_idx
                if received == "success":
                    real_end = min(symbol_num, revision_num + require_more_revision)

+ 42
- 0
abl/reasoning/search_engine/zoopt.py View File

@@ -0,0 +1,42 @@
from typing import List, Tuple, Union

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ...structures import ListData
from ..reasoner import ReasonerBase
from ..search_based_kb import SearchBasedKB
from .base_search_engine import BaseSearchEngine


class Zoopt(BaseSearchEngine):
    """
    Search engine that picks the revision positions with the ZOOpt optimizer.

    A 0/1 mask over the symbols is optimized: entry ``i`` being non-zero
    means position ``i`` is revised. The mask is scored by the smallest
    distance among the candidates the knowledge base returns for it.
    """

    def __init__(self, reasoner: ReasonerBase, kb: SearchBasedKB) -> None:
        """
        Parameters
        ----------
        reasoner : ReasonerBase
            Supplies ``_get_dist_list`` used to score candidates.
        kb : SearchBasedKB
            Supplies ``revise_at_idx`` used to produce candidates.
        """
        self.reasoner = reasoner
        self.kb = kb

    def score_func(self, data_sample: ListData, solution: Solution):
        """
        Score one ZOOpt solution (lower is better).

        Returns the minimum distance among the candidates obtained by
        revising the positions flagged in ``solution``; when no candidate
        exists, returns ``data_sample["symbol_num"]`` as a penalty.
        """
        # np.asarray keeps the element-wise comparison valid even when
        # get_x() hands back a plain list instead of an ndarray.
        revision_idx = np.where(np.asarray(solution.get_x()) != 0)[0]
        candidates = self.kb.revise_at_idx(data_sample, revision_idx)
        if len(candidates) > 0:
            return np.min(self.reasoner._get_dist_list(data_sample, candidates))
        return data_sample["symbol_num"]

    @staticmethod
    def constraint(solution: Solution, max_revision_num: int):
        """
        Return the remaining revision budget of ``solution``.

        The value is ``max_revision_num`` minus the number of flagged
        positions; it is negative when the solution revises too many.
        """
        return max_revision_num - np.asarray(solution.get_x()).sum()

    def generator(
        self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
    ) -> Union[List, Tuple, np.ndarray]:
        """
        Yield the single set of revision indices chosen by ZOOpt.

        Parameters
        ----------
        data_sample : ListData
            Sample providing the ``"symbol_num"`` entry.
        max_revision_num : int
            Upper bound on the number of positions revised at once.
        require_more_revision : int, optional
            Accepted for interface compatibility; not used by this engine.

        Yields
        ------
        numpy.ndarray
            Indices of the positions to revise.
        """
        symbol_num = data_sample["symbol_num"]
        dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
        objective = Objective(
            # score_func is a bound method: passing ``self`` again (as the
            # original did) raised TypeError on every evaluation.
            lambda solution: self.score_func(data_sample, solution),
            dim=dimension,
            constraint=lambda solution: self.constraint(solution, max_revision_num),
        )
        parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
        solution = Opt.min(objective, parameter).get_x()
        # Consumers treat yielded values as revision indices, so convert the
        # 0/1 mask into the positions that are flagged.
        yield np.where(np.asarray(solution) != 0)[0]

+ 403
- 0
tests/test_reasoning.py View File

@@ -0,0 +1,403 @@

from abl.reasoning import ReasonerBase, BaseKB, GroundKB, PrologBasedKB

if __name__ == "__main__":
prob1 = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

prob2 = [
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

class add_KB(BaseKB):
def __init__(self, pseudo_label_list=list(range(10)), use_cache=True):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)

class add_GroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)

def logic_forward(self, nums):
return sum(nums)

def test_add(reasoner):
res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
print(res)
print()

print("add_KB with GKB:")
kb = add_GroundKB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("add_KB without GKB, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl"
)
reasoner = ReasonerBase(kb, "confidence")
test_add(reasoner)

print("PrologBasedKB with add.pl using zoopt:")
kb = PrologBasedKB(
pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
test_add(reasoner)

print("add_KB with multiple inputs at once:")
multiple_prob = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
]

kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()

class HWF_KB(BaseKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
max_err=1e-3,
):
super().__init__(pseudo_label_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

class HWF_GroundKB(GroundKB):
def __init__(
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
GKB_len_list=[1, 3, 5, 7],
max_err=1e-3,
):
super().__init__(pseudo_label_list, GKB_len_list, max_err)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))

def test_hwf(reasoner):
res = reasoner.batch_abduce(
[None],
[["5", "+", "2"]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "+", "9"]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None],
[["5", "8", "8", "8", "8"]],
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()

def test_hwf_multiple(reasoner, max_revisions):
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[0],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 64],
max_revision=max_revisions[1],
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[None, None],
[["5", "+", "2"], ["5", "+", "9"]],
[3, 65],
max_revision=max_revisions[2],
require_more_revision=0,
)
print(res)
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with GKB, max_err=1")
kb = HWF_GroundKB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(max_err=1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf_multiple(reasoner, max_revisions=[1, 3, 3])

print("max_revision is float")
test_hwf_multiple(reasoner, max_revisions=[0.5, 0.9, 0.9])

class HED_prolog_KB(PrologBasedKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

def abduce_rules(self, pred_res):
pl_query = "consistent_inst_feature(%s, X)." % pred_res
prolog_result = list(self.prolog.query(pl_query))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules

class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, use_zoopt=True)

def _revise_at_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_at_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, symbol_num, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_at_idxs(pred_res, y, all_revision_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0
import math

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs1 = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print("HED_kb consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print("HED_Reasoner abduce")
res = reasoner.abduce([[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs))
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1)
)
print(res)
res = reasoner.abduce(
[[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2)
)
print(res)
print()

print("HED_Reasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)

Loading…
Cancel
Save