Browse Source

[ENH] Let ReasonerBase be compatible with the index style output of the ABLModel.

pull/3/head
Gao Enhao 3 years ago
parent
commit
da4912fcb4
1 changed files with 448 additions and 194 deletions
  1. +448
    -194
      abl/reasoning/reasoner.py

+ 448
- 194
abl/reasoning/reasoner.py View File

@@ -2,25 +2,42 @@ import abc
import numpy as np
from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
hamming_dist,
float_parameter,
)


class ReasonerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False):
def __init__(self, kb, dist_func="hamming", mapping=None, zoopt=False):
if not (dist_func == "hamming" or dist_func == "confidence"):
raise NotImplementedError

self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence'
self.dist_func = dist_func
self.zoopt = zoopt
if dist_func == 'confidence':
self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
if mapping is None:
self.mapping = dict(
zip(
list(range(len(self.kb.pseudo_label_list))),
self.kb.pseudo_label_list,
)
)
else:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, pred_res, pred_res_prob, candidates):
def _get_cost_list(self, pseudo_label, pred_res_prob, candidates):
"""
Get the cost list of candidates based on the distance function.

Parameters
----------
pred_res : list
The predicted result.
pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
The predicted result probability.
candidates : list
@@ -31,21 +48,21 @@ class ReasonerBase(abc.ABC):
list
The cost list of candidates.
"""
if self.dist_func == 'hamming':
return hamming_dist(pred_res, candidates)
elif self.dist_func == 'confidence':
candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates]
if self.dist_func == "hamming":
return hamming_dist(pseudo_label, candidates)
elif self.dist_func == "confidence":
candidates = [list(map(lambda x: self.remapping[x], c)) for c in candidates]
return confidence_dist(pred_res_prob, candidates)

def _get_one_candidate(self, pred_res, pred_res_prob, candidates):
def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates):
"""
Get the best candidate based on the distance function.

Parameters
----------
pred_res : list
The predicted result.
pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
The predicted result probability.
candidates : list
@@ -58,23 +75,14 @@ class ReasonerBase(abc.ABC):
"""
if len(candidates) == 0:
return []
elif len(candidates) == 1 or self.zoopt:
elif len(candidates) == 1:
return candidates[0]
else:
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates)
candidate = candidates[np.argmin(cost_list)]
return candidate
def _zoopt_revision_score_single(self, sol_x, pred_res, pred_res_prob, y):
revision_idx = np.where(sol_x != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):

def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol):
"""
Get the revision score for a single solution.

@@ -84,6 +92,8 @@ class ReasonerBase(abc.ABC):
Solution to evaluate.
pred_res : list
List of predicted results.
pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
List of probabilities for predicted results.
y : str
@@ -95,23 +105,27 @@ class ReasonerBase(abc.ABC):
The revision score for the given solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
candidates = self.revise_by_idx(pseudo_label, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates))
else:
return len(pred_res)
def _constrain_revision_num(self, solution, max_revision_num):
x = solution.get_x()
return max_revision_num - x.sum()

def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_revision_num):
def zoopt_get_solution(
self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num
):
"""Get the optimal solution using the Zoopt library.

Parameters
----------
pred_res : list
List of predicted results.
pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
List of probabilities for predicted results.
y : str
@@ -127,21 +141,23 @@ class ReasonerBase(abc.ABC):
length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
objective = Objective(
lambda sol: self.zoopt_revision_score(pred_res, pred_res_prob, y, sol),
lambda sol: self.zoopt_revision_score(
pred_res, pseudo_label, pred_res_prob, y, sol
),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x()
return solution
def revise_by_idx(self, pred_res, y, revision_idx):
def revise_by_idx(self, pseudo_label, y, revision_idx):
"""Get the revisions corresponding to the given indices.

Parameters
----------
pred_res : list
List of predicted results.
pseudo_label : list
List of predicted pseudo labels.
y : str
Ground truth for the predicted results.
revision_idx : array-like
@@ -152,7 +168,7 @@ class ReasonerBase(abc.ABC):
list
The revisions corresponding to the given indices.
"""
return self.kb.revise_by_idx(pred_res, y, revision_idx)
return self.kb.revise_by_idx(pseudo_label, y, revision_idx)

def abduce(self, data, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data.
@@ -162,7 +178,7 @@ class ReasonerBase(abc.ABC):
data : tuple
Tuple containing the predicted results, predicted result probabilities, and y.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, use all revisions. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.
@@ -173,16 +189,22 @@ class ReasonerBase(abc.ABC):
The abduced revisions.
"""
pred_res, pred_res_prob, y = data
pseudo_label = [self.mapping[_idx] for _idx in pred_res]

max_revision_num = float_parameter(max_revision, len(flatten(pred_res)))

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num)
solution = self.zoopt_get_solution(
pred_res, pseudo_label, pred_res_prob, y, max_revision_num
)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
candidates = self.revise_by_idx(pseudo_label, y, revision_idx)
else:
candidates = self.kb.abduce_candidates(pred_res, y, max_revision_num, require_more_revision)
candidates = self.kb.abduce_candidates(
pred_res, y, max_revision_num, require_more_revision
)

candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
candidate = self._get_one_candidate(pseudo_label, pred_res_prob, candidates)
return candidate

def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
@@ -195,7 +217,7 @@ class ReasonerBase(abc.ABC):
Y : list
List of ground truths.
max_revision : int or float, optional
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
If -1, use all revisions. Defaults to -1.
require_more_revision : int, optional
Number of additional revisions to require. Defaults to 0.
@@ -205,8 +227,11 @@ class ReasonerBase(abc.ABC):
list
The abduced revisions.
"""
return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
return [
self.abduce((z, prob, y), max_revision, require_more_revision)
for z, prob, y in zip(Z["label"], Z["prob"], Y)
]

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)
@@ -215,120 +240,224 @@ class ReasonerBase(abc.ABC):
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results
def __call__(self, Z, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(Z, Y, max_revision, require_more_revision)

if __name__ == '__main__':

if __name__ == "__main__":
from kb import KBBase, prolog_KB
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

prob1 = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]
prob2 = [
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
]

class add_KB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):
def __init__(
self,
pseudo_label_list=list(range(10)),
len_list=[2],
GKB_flag=False,
max_err=0,
use_cache=True,
):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

def logic_forward(self, nums):
return sum(nums)
print('add_KB with GKB:')

print("add_KB with GKB:")
kb = add_KB(GKB_flag=True)
reasoner = ReasonerBase(kb, 'confidence')
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0
)
print(res)
print()
print('add_KB without GKB:')
print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, 'confidence')
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0
)
print(res)
print()
print('add_KB without GKB:, no cache')
print("add_KB without GKB:, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, 'confidence')
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0
)
print(res)
print()
print('prolog_KB with add.pl:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
reasoner = ReasonerBase(kb, 'confidence')
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)

print("prolog_KB with add.pl:")
kb = prolog_KB(
pseudo_label_list=list(range(10)),
pl_file="../examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0
)
print(res)
print()

print('prolog_KB with add.pl using zoopt:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
reasoner = ReasonerBase(kb, 'confidence', zoopt=True)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print("prolog_KB with add.pl using zoopt:")
kb = prolog_KB(
pseudo_label_list=list(range(10)),
pl_file="../examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", zoopt=True)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob2}, [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
{"cls": [[1, 1]], "prob": prob1}, [20], max_revision=2, require_more_revision=0
)
print(res)
print()
print('add_KB with multiple inputs at once:')
multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

print("add_KB with multiple inputs at once:")
multiple_prob = [
[
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
[
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
],
]

kb = add_KB()
reasoner = ReasonerBase(kb, 'confidence')
res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=1)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
{"cls": [[1, 1], [1, 2]], "prob": multiple_prob},
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [[1, 1], [1, 2]], "prob": multiple_prob},
[4, 8],
max_revision=2,
require_more_revision=1,
)
print(res)
print()
class HWF_KB(KBBase):
def __init__(
self,
pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],
self,
pseudo_label_list=[
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"+",
"-",
"times",
"div",
],
len_list=[1, 3, 5, 7],
GKB_flag=False,
max_err=1e-3,
use_cache=True
use_cache=True,
):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

@@ -336,9 +465,19 @@ if __name__ == '__main__':
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
if i % 2 == 0 and formula[i] not in [
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]:
return False
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

@@ -346,91 +485,183 @@ if __name__ == '__main__':
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval(''.join(formula))
print('HWF_KB with GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
reasoner = ReasonerBase(kb, 'hamming')
res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
return eval("".join(formula))

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"]], "prob": [None]},
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "9"]], "prob": [None]},
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "8", "8", "8", "8"]], "prob": [None]},
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()
print('HWF_KB without GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
reasoner = ReasonerBase(kb, 'hamming')
res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)

print("HWF_KB without GKB, max_err=0.1")
kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"]], "prob": [None]},
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "9"]], "prob": [None]},
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "8", "8", "8", "8"]], "prob": [None]},
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()
print('HWF_KB with GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
reasoner = ReasonerBase(kb, 'hamming')
res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)

print("HWF_KB with GKB, max_err=1")
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err=1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "9"]], "prob": [None]},
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"]], "prob": [None]},
[1.67],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "8", "8", "8", "8"]], "prob": [None]},
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()
print('HWF_KB without GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
reasoner = ReasonerBase(kb, 'hamming')
res = reasoner.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)

print("HWF_KB without GKB, max_err=1")
kb = HWF_KB(len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "9"]], "prob": [None]},
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"]], "prob": [None]},
[1.67],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "8", "8", "8", "8"]], "prob": [None]},
[3.17],
max_revision=5,
require_more_revision=3,
)
print(res)
print()
print('HWF_KB with multiple inputs at once:')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
reasoner = ReasonerBase(kb, 'hamming')
res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=1, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=3, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_revision=3, require_more_revision=0)

print("HWF_KB with multiple inputs at once:")
kb = HWF_KB(len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]},
[3, 64],
max_revision=1,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]},
[3, 64],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]},
[3, 65],
max_revision=3,
require_more_revision=0,
)
print(res)
print()
print('max_revision is float')
res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.5, require_more_revision=0)
print(res)
res = reasoner.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.9, require_more_revision=0)
print("max_revision is float")
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]},
[3, 64],
max_revision=0.5,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
{"cls": [["5", "+", "2"], ["5", "+", "9"]], "prob": [None, None]},
[3, 64],
max_revision=0.9,
require_more_revision=0,
)
print(res)
print()
class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)
def consist_rule(self, exs, rules):
rules = str(rules).replace("\'","")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
rules = str(rules).replace("'", "")
return (
len(
list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))
)
!= 0
)

def abduce_rules(self, pred_res):
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
prolog_result = list(
self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)
)
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]['X']
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules
class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func='hamming'):
def __init__(self, kb, dist_func="hamming"):
super().__init__(kb, dist_func, zoopt=True)
def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
@@ -439,14 +670,14 @@ if __name__ == '__main__':
pred.append(pred_res[idx])
k.append(y[idx])
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_by_idx(pred, k, revision_idx)
return candidate
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
candidate_size = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
@@ -455,21 +686,26 @@ if __name__ == '__main__':
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs)
candidate = self._revise_by_idxs(
pred_res, y, all_revision_flag, idxs
)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
lefted_idxs = [
i for i in lefted_idxs if i not in max_candidate_idxs
]
candidate_size.sort()
score = 0
import math

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
return score
@@ -477,31 +713,49 @@ if __name__ == '__main__':
def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/hed/datasets/learn_add.pl')
kb = HED_prolog_KB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="../examples/hed/datasets/learn_add.pl",
)
reasoner = HED_Reasoner(kb)
consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]]
inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])']

print('HED_kb logic forward')
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs1 = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"]

print("HED_kb logic forward")
print(kb.logic_forward(consist_exs))
print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2))
print()
print('HED_kb consist rule')
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules))
print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print("HED_kb consist rule")
print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules))
print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules))
print()

print('HED_Reasoner abduce')
res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)))
print("HED_Reasoner abduce")
res = reasoner.abduce(
(consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))
)
print(res)
res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)))
res = reasoner.abduce(
(inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))
)
print(res)
res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)))
res = reasoner.abduce(
(inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))
)
print(res)
print()

print('HED_Reasoner abduce rules')
print("HED_Reasoner abduce rules")
abduced_rules = reasoner.abduce_rules(consist_exs)
print(abduced_rules)
print(abduced_rules)

Loading…
Cancel
Save