Browse Source

[MNT] use black to reformat reasoner.py

pull/3/head
Gao Enhao 2 years ago
parent
commit
a816bf0b74
1 changed files with 208 additions and 68 deletions
  1. +208
    -68
      abl/reasoning/reasoner.py

+ 208
- 68
abl/reasoning/reasoner.py View File

@@ -1,5 +1,4 @@
import numpy as np
from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
confidence_dist,
@@ -10,7 +9,7 @@ from ..utils.utils import (
)


class ReasonerBase():
class ReasonerBase:
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
"""
Root class for all reasoner in the ABL system.
@@ -31,15 +30,17 @@ class ReasonerBase():
NotImplementedError
If the specified distance function is neither "hamming" nor "confidence".
"""
if not (dist_func == "hamming" or dist_func == "confidence"):
raise NotImplementedError # Only hamming or confidence distance is available.
raise NotImplementedError # Only hamming or confidence distance is available.

self.kb = kb
self.dist_func = dist_func
self.use_zoopt = use_zoopt
if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
self.mapping = {
index: label for index, label in enumerate(self.kb.pseudo_label_list)
}
else:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))
@@ -130,7 +131,9 @@ class ReasonerBase():
x = solution.get_x()
return max_revision_num - x.sum()

def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num):
def zoopt_get_solution(
self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
):
"""Get the optimal solution using the Zoopt library.

Parameters
@@ -151,9 +154,13 @@ class ReasonerBase():
array-like
The optimal solution, i.e., where to revise predict pseudo label.
"""
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
dimension = Dimension(
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num
)
objective = Objective(
lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol),
lambda sol: self.zoopt_revision_score(
symbol_num, pred_pseudo_label, pred_prob, y, sol
),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
@@ -181,7 +188,9 @@ class ReasonerBase():
"""
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx)

def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0):
def abduce(
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0
):
"""
Perform revision by abduction on the given data.

@@ -208,16 +217,22 @@ class ReasonerBase():
max_revision_num = float_parameter(max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num)
solution = self.zoopt_get_solution(
symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num
)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
else:
candidates = self.kb.abduce_candidates(pred_pseudo_label, y, max_revision_num, require_more_revision)
candidates = self.kb.abduce_candidates(
pred_pseudo_label, y, max_revision_num, require_more_revision
)

candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)
return candidate

def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0):
def batch_abduce(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
"""
Perform abduction on the given data in batches.

@@ -240,9 +255,14 @@ class ReasonerBase():
list
The abduced revisions in batches.
"""
return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision)
for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)]

return [
self.abduce(
_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision
)
for _pred_prob, _pred_pseudo_label, _Y in zip(
pred_prob, pred_pseudo_label, Y
)
]

# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
@@ -253,8 +273,12 @@ class ReasonerBase():
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)
def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
return self.batch_abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)


if __name__ == "__main__":
@@ -282,7 +306,9 @@ if __name__ == "__main__":
max_err=0,
use_cache=True,
):
super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)
super().__init__(
pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache
)

def logic_forward(self, nums):
return sum(nums)
@@ -290,45 +316,75 @@ if __name__ == "__main__":
print("add_KB with GKB:")
kb = add_KB(prebuild_GKB=True)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
print(res)
print()

print("add_KB without GKB:")
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
print(res)
print()

print("add_KB without GKB:, no cache")
kb = add_KB(use_cache=False)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
print(res)
print()

@@ -338,15 +394,25 @@ if __name__ == "__main__":
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
print(res)
print()

@@ -356,15 +422,25 @@ if __name__ == "__main__":
pl_file="examples/mnist_add/datasets/add.pl",
)
reasoner = ReasonerBase(kb, "confidence", use_zoopt=True)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0
)
print(res)
res = reasoner.batch_abduce([[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0)
res = reasoner.batch_abduce(
[[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0
)
print(res)
print()

@@ -383,13 +459,19 @@ if __name__ == "__main__":
kb = add_KB()
reasoner = ReasonerBase(kb, "confidence")
res = reasoner.batch_abduce(
[[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8],
[[1, 1], [1, 2]],
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8],
[[1, 1], [1, 2]],
multiple_prob,
[[1, 1], [1, 2]],
[4, 8],
max_revision=2,
require_more_revision=1,
)
@@ -419,7 +501,9 @@ if __name__ == "__main__":
max_err=1e-3,
use_cache=True,
):
super().__init__(pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache)
super().__init__(
pseudo_label_list, prebuild_GKB, GKB_len_list, max_err, use_cache
)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
@@ -453,19 +537,28 @@ if __name__ == "__main__":
kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,2]], [3],
[["5", "+", "2"]],
[None],
[[5, 10, 2]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,9]], [65],
[["5", "+", "2"]],
[None],
[[5, 10, 9]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17],
[["5", "8", "8", "8", "8"]],
[None],
[[5, 8, 8, 8, 8]],
[3.17],
max_revision=5,
require_more_revision=3,
)
@@ -476,19 +569,28 @@ if __name__ == "__main__":
kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,2]], [3],
[["5", "+", "2"]],
[None],
[[5, 10, 2]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,9]], [65],
[["5", "+", "2"]],
[None],
[[5, 10, 9]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17],
[["5", "8", "8", "8", "8"]],
[None],
[[5, 8, 8, 8, 8]],
[3.17],
max_revision=5,
require_more_revision=3,
)
@@ -499,19 +601,28 @@ if __name__ == "__main__":
kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,2]], [3],
[["5", "+", "2"]],
[None],
[[5, 10, 2]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,9]], [65],
[["5", "+", "2"]],
[None],
[[5, 10, 9]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17],
[["5", "8", "8", "8", "8"]],
[None],
[[5, 8, 8, 8, 8]],
[3.17],
max_revision=5,
require_more_revision=3,
)
@@ -522,19 +633,28 @@ if __name__ == "__main__":
kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,2]], [3],
[["5", "+", "2"]],
[None],
[[5, 10, 2]],
[3],
max_revision=2,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"]], [None], [[5,10,9]], [65],
[["5", "+", "2"]],
[None],
[[5, 10, 9]],
[65],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "8", "8", "8", "8"]], [None], [[5,8,8,8,8]], [3.17],
[["5", "8", "8", "8", "8"]],
[None],
[[5, 8, 8, 8, 8]],
[3.17],
max_revision=5,
require_more_revision=3,
)
@@ -545,21 +665,27 @@ if __name__ == "__main__":
kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
res = reasoner.batch_abduce(
[["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]],
[["5", "+", "2"], ["5", "+", "9"]],
[None, None],
[[5, 10, 2], [5, 10, 9]],
[3, 64],
max_revision=1,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]],
[["5", "+", "2"], ["5", "+", "9"]],
[None, None],
[[5, 10, 2], [5, 10, 9]],
[3, 64],
max_revision=3,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]],
[["5", "+", "2"], ["5", "+", "9"]],
[None, None],
[[5, 10, 2], [5, 10, 9]],
[3, 65],
max_revision=3,
require_more_revision=0,
@@ -568,14 +694,18 @@ if __name__ == "__main__":
print()
print("max_revision is float")
res = reasoner.batch_abduce(
[["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]],
[["5", "+", "2"], ["5", "+", "9"]],
[None, None],
[[5, 10, 2], [5, 10, 9]],
[3, 64],
max_revision=0.5,
require_more_revision=0,
)
print(res)
res = reasoner.batch_abduce(
[["5", "+", "2"], ["5", "+", "9"]], [None, None], [[5,10,2],[5,10,9]],
[["5", "+", "2"], ["5", "+", "9"]],
[None, None],
[[5, 10, 2], [5, 10, 9]],
[3, 64],
max_revision=0.9,
require_more_revision=0,
@@ -629,7 +759,9 @@ if __name__ == "__main__":
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs)
candidate = self._revise_by_idxs(
pred_res, y, all_revision_flag, idxs
)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
@@ -640,7 +772,9 @@ if __name__ == "__main__":
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
candidate_size.append(len(removed) + 1)
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
lefted_idxs = [
i for i in lefted_idxs if i not in max_candidate_idxs
]
candidate_size.sort()
score = 0
import math
@@ -681,11 +815,17 @@ if __name__ == "__main__":
print()

print("HED_Reasoner abduce")
res = reasoner.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)))
res = reasoner.abduce(
(consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))
)
print(res)
res = reasoner.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)))
res = reasoner.abduce(
(inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))
)
print(res)
res = reasoner.abduce((inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2)))
res = reasoner.abduce(
(inconsist_exs2, [[[None]]] * len(inconsist_exs2), [None] * len(inconsist_exs2))
)
print(res)
print()



Loading…
Cancel
Save