|
|
|
@@ -5,7 +5,7 @@ import numpy as np |
|
|
|
from collections import defaultdict |
|
|
|
from itertools import product, combinations |
|
|
|
|
|
|
|
from abl.utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list |
|
|
|
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list |
|
|
|
|
|
|
|
from multiprocessing import Pool |
|
|
|
|
|
|
|
@@ -14,9 +14,9 @@ import pyswip |
|
|
|
|
|
|
|
class KBBase(ABC): |
|
|
|
""" |
|
|
|
Base class for reasoner. |
|
|
|
Base class for knowledge base. |
|
|
|
|
|
|
|
Attributes |
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pseudo_label_list : list |
|
|
|
List of possible pseudo labels. |
|
|
|
@@ -30,10 +30,11 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users creating their own KB should inherit from this base class. For the inherited |
|
|
|
subclass, it's mandatory to provide `pseudo_label_list` and override the `logic_forward` |
|
|
|
function. After that, other operations (e.g. how to perform abductive reasoning) |
|
|
|
will be automatically set up. |
|
|
|
Users should inherit from this base class to build their own knowledge base. For the |
|
|
|
user-built KB (an inherited subclass), it's only required for the user to provide the |
|
|
|
`pseudo_label_list` and override the `logic_forward` function (specifying how to |
|
|
|
perform logical reasoning). After that, other operations (e.g. how to perform abductive |
|
|
|
reasoning) will be automatically set up. |
|
|
|
""" |
|
|
|
def __init__(self, pseudo_label_list, max_err=0, use_cache=True): |
|
|
|
if not isinstance(pseudo_label_list, list): |
|
|
|
@@ -44,6 +45,9 @@ class KBBase(ABC): |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
""" |
|
|
|
Specify how to perform logical reasoning. Users are required to override this method. |
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): |
|
|
|
@@ -55,7 +59,7 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : Any |
|
|
|
Ground truth. |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int, optional |
|
|
|
@@ -85,7 +89,7 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : Any |
|
|
|
Ground truth. |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
revision_idx : array-like |
|
|
|
Indices of where revisions should be made to the predicted pseudo label. |
|
|
|
""" |
|
|
|
@@ -122,8 +126,8 @@ class KBBase(ABC): |
|
|
|
---------- |
|
|
|
pred_pseudo_label : List[Any] |
|
|
|
Predicted pseudo label. |
|
|
|
y : any |
|
|
|
Ground truth. |
|
|
|
y : Any |
|
|
|
Ground truth for the result (after passing through the logic part). |
|
|
|
max_revision_num : int |
|
|
|
The upper limit on the number of revisions. |
|
|
|
require_more_revision : int |
|
|
|
@@ -165,30 +169,58 @@ class KBBase(ABC): |
|
|
|
pred_pseudo_label = hashable_to_list(pred_pseudo_label) |
|
|
|
y = hashable_to_list(y) |
|
|
|
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ground_KB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): |
|
|
|
""" |
|
|
|
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt |
|
|
|
upon class initialization, storing all potential candidates along with |
|
|
|
their respective results after passing through the logic part. Ground KB can |
|
|
|
enhance the speed of abductive reasoning. For more on this, refer to the |
|
|
|
`abduce_candidates` method in this class. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pseudo_label_list : list |
|
|
|
Refer to class `KBBase`. |
|
|
|
GKB_len_list : list |
|
|
|
List of possible lengths of pseudo label. |
|
|
|
max_err : float, optional |
|
|
|
Refer to class `KBBase`. |
|
|
|
|
|
|
|
Notes |
|
|
|
----- |
|
|
|
Users can also inherit from this class to build their own knowledge base. |
|
|
|
Similar to `KBBase`, users are only required to provide the `pseudo_label_list` |
|
|
|
and override the `logic_forward` function. Additionally, users should provide |
|
|
|
the `GKB_len_list`. After that, other operations (e.g. auto-construction of |
|
|
|
GKB, and how to perform abductive reasoning) will be automatically set up. |
|
|
|
""" |
|
|
|
def __init__(self, pseudo_label_list, GKB_len_list, max_err=0): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
|
|
|
|
if not isinstance(GKB_len_list, list): |
|
|
|
raise TypeError("GKB_len_list should be list") |
|
|
|
self.GKB_len_list = GKB_len_list |
|
|
|
self.base = {} |
|
|
|
self.GKB = {} |
|
|
|
X, Y = self._get_GKB() |
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(x) |
|
|
|
self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) |
|
|
|
|
|
|
|
# For parallel version of _get_GKB |
|
|
|
|
|
|
|
def _get_XY_list(self, args): |
|
|
|
pre_x, post_x_it = args[0], args[1] |
|
|
|
XY_list = [] |
|
|
|
for post_x in post_x_it: |
|
|
|
x = (pre_x,) + post_x |
|
|
|
y = self.logic_forward(x) |
|
|
|
if y is not None: |
|
|
|
if y is not np.inf: |
|
|
|
XY_list.append((x, y)) |
|
|
|
return XY_list |
|
|
|
|
|
|
|
# Parallel _get_GKB |
|
|
|
def _get_GKB(self): |
|
|
|
""" |
|
|
|
Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`. |
|
|
|
""" |
|
|
|
X, Y = [], [] |
|
|
|
for length in self.GKB_len_list: |
|
|
|
arg_list = [] |
|
|
|
@@ -208,13 +240,37 @@ class ground_KB(KBBase): |
|
|
|
return X, Y |
|
|
|
|
|
|
|
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): |
|
|
|
return self._abduce_by_GKB(pred_pseudo_label, y, max_revision_num, require_more_revision) |
|
|
|
""" |
|
|
|
Perform abductive reasoning by directly retrieving consistent candidates from |
|
|
|
the prebuilt GKB. In this way, the time-consuming exhaustive search can be |
|
|
|
avoided. |
|
|
|
This is an overridden function. For more information about the parameters and |
|
|
|
returns, refer to the function of the same name in class `KBBase`. |
|
|
|
""" |
|
|
|
if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: |
|
|
|
return [] |
|
|
|
|
|
|
|
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) |
|
|
|
if len(all_candidates) == 0: |
|
|
|
return [] |
|
|
|
|
|
|
|
cost_list = hamming_dist(pred_pseudo_label, all_candidates) |
|
|
|
min_revision_num = np.min(cost_list) |
|
|
|
revision_num = min(max_revision_num, min_revision_num + require_more_revision) |
|
|
|
idxs = np.where(cost_list <= revision_num)[0] |
|
|
|
candidates = [all_candidates[idx] for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
def _find_candidate_GKB(self, pred_pseudo_label, y): |
|
|
|
""" |
|
|
|
Retrieve consistent candidates from the prebuilt GKB. If `max_err` is greater |
|
|
|
than 0, return all candidates whose logical results fall within the |
|
|
|
[y - max_err, y + max_err] range. |
|
|
|
""" |
|
|
|
if self.max_err == 0: |
|
|
|
return self.base[len(pred_pseudo_label)][y] |
|
|
|
return self.GKB[len(pred_pseudo_label)][y] |
|
|
|
else: |
|
|
|
potential_candidates = self.base[len(pred_pseudo_label)] |
|
|
|
potential_candidates = self.GKB[len(pred_pseudo_label)] |
|
|
|
key_list = list(potential_candidates.keys()) |
|
|
|
key_idx = bisect.bisect_left(key_list, y) |
|
|
|
|
|
|
|
@@ -233,35 +289,7 @@ class ground_KB(KBBase): |
|
|
|
else: |
|
|
|
break |
|
|
|
return all_candidates |
|
|
|
|
|
|
|
def _abduce_by_GKB(self, pred_pseudo_label, y, max_revision_num, require_more_revision): |
|
|
|
if self.base == {} or len(pred_pseudo_label) not in self.GKB_len_list: |
|
|
|
return [] |
|
|
|
|
|
|
|
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) |
|
|
|
if len(all_candidates) == 0: |
|
|
|
return [] |
|
|
|
|
|
|
|
cost_list = hamming_dist(pred_pseudo_label, all_candidates) |
|
|
|
min_revision_num = np.min(cost_list) |
|
|
|
revision_num = min(max_revision_num, min_revision_num + require_more_revision) |
|
|
|
idxs = np.where(cost_list <= revision_num)[0] |
|
|
|
candidates = [all_candidates[idx] for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
if not self.GKB_flag: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return sum(len(c) for c in dic.values()) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
if not self.GKB_flag: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return sum(self._dict_len(v) for v in self.base.values()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class prolog_KB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, pl_file, max_err=0): |
|
|
|
|