diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 1cbd291..eee456c 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -5,7 +5,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from abl.utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -14,9 +14,9 @@ import pyswip class KBBase(ABC): """ - Base class for reasoner. + Base class for knowledge base. - Attributes + Parameters ---------- pseudo_label_list : list List of possible pseudo labels. @@ -30,10 +30,11 @@ class KBBase(ABC): Notes ----- - Users creating there own KB should inherit from this base class. For the inherited - subclass, it's mandatory to provide `pseudo_label_list` and override the `logic_forward` - function. After that, other operations (e.g. how to perform abductive reasoning) - will be automatically set up. + Users should inherit from this base class to build their own knowledge base. For the + user-build KB (an inherited subclass), it's only required for the user to provide the + `pseudo_label_list` and override the `logic_forward` function (specifying how to + perform logical reasoning). After that, other operations (e.g. how to perform abductive + reasoning) will be automatically set up. """ def __init__(self, pseudo_label_list, max_err=0, use_cache=True): if not isinstance(pseudo_label_list, list): @@ -44,6 +45,9 @@ class KBBase(ABC): @abstractmethod def logic_forward(self, pseudo_labels): + """ + How to perform logical reasoning. Users are required to provide this. + """ pass def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): @@ -55,7 +59,7 @@ class KBBase(ABC): pred_pseudo_label : List[Any] Predicted pseudo label. y : any - Ground truth. + Ground truth for the result (after passing through the logic part). max_revision_num : int The upper limit on the number of revisions. require_more_revision : int, optional @@ -85,7 +89,7 @@ class KBBase(ABC): pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth. + Ground truth for the result (after passing through the logic part). revision_idx : array-like Indices of where revisions should be made to the predicted pseudo label. """ @@ -122,8 +126,8 @@ class KBBase(ABC): ---------- pred_pseudo_label : List[Any] Predicted pseudo label. - y : any - Ground truth. + y : Any + Ground truth for the result (after passing through the logic part). max_revision_num : int The upper limit on the number of revisions. require_more_revision : int @@ -165,30 +169,58 @@ class KBBase(ABC): pred_pseudo_label = hashable_to_list(pred_pseudo_label) y = hashable_to_list(y) return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) - + + class ground_KB(KBBase): - def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): + """ + Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt + upon class initialization, stroing all potential candidates along with + their respective results after passing through the logic part. Ground KB can + enhance the speed of abductive reasoning. For more on this, refer to the + `abduce_candidates` method in this class. + + Parameters + ---------- + pseudo_label_list : list + Refer to class `KBBase`. + GKB_len_list : list + List of possible lengths of pseudo label. + max_err : float, optional + Refer to class `KBBase`. + + Notes + ----- + Users can also inherit from this class to build their own knowledge base. + Similar to `KBBase`, users are only required to provide the `pseudo_label_list` + and override the `logic_forward` function. Additionally, users should provide + the `GKB_len_list`. After that, other operations (e.g. auto-construction of + GKB, and how to perform abductive reasoning) will be automatically set up. + """ + def __init__(self, pseudo_label_list, GKB_len_list, max_err=0): super().__init__(pseudo_label_list, max_err) - + if not isinstance(GKB_len_list, list): + raise TypeError("GKB_len_list should be list") self.GKB_len_list = GKB_len_list - self.base = {} + self.GKB = {} X, Y = self._get_GKB() for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) + self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) - # For parallel version of _get_GKB + def _get_XY_list(self, args): pre_x, post_x_it = args[0], args[1] XY_list = [] for post_x in post_x_it: x = (pre_x,) + post_x y = self.logic_forward(x) - if y is not None: + if y is not np.inf: XY_list.append((x, y)) return XY_list - # Parallel _get_GKB def _get_GKB(self): + """ + Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`. + """ X, Y = [], [] for length in self.GKB_len_list: arg_list = [] @@ -208,13 +240,37 @@ class ground_KB(KBBase): return X, Y def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): - return self._abduce_by_GKB(pred_pseudo_label, y, max_revision_num, require_more_revision) + """ + Perform abductive reasoning by directly retrieving consistent candidates from + the prebuilt GKB. In this way, the time-consuming exhaustive search can be + avoided. + This is an overridden function. For more information about the parameters and + returns, refer to the function of the same name in class `KBBase`. + """ + if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: + return [] + + all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) + if len(all_candidates) == 0: + return [] + + cost_list = hamming_dist(pred_pseudo_label, all_candidates) + min_revision_num = np.min(cost_list) + revision_num = min(max_revision_num, min_revision_num + require_more_revision) + idxs = np.where(cost_list <= revision_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates def _find_candidate_GKB(self, pred_pseudo_label, y): + """ + Retrieve consistent candidates from the prebuilt GKB. If `max_err` is greater + than 0, return all candidates whose logical results fall within the + [y - max_err, y + max_err] range. + """ if self.max_err == 0: - return self.base[len(pred_pseudo_label)][y] + return self.GKB[len(pred_pseudo_label)][y] else: - potential_candidates = self.base[len(pred_pseudo_label)] + potential_candidates = self.GKB[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) key_idx = bisect.bisect_left(key_list, y) @@ -233,35 +289,7 @@ class ground_KB(KBBase): else: break return all_candidates - - def _abduce_by_GKB(self, pred_pseudo_label, y, max_revision_num, require_more_revision): - if self.base == {} or len(pred_pseudo_label) not in self.GKB_len_list: - return [] - - all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) - if len(all_candidates) == 0: - return [] - cost_list = hamming_dist(pred_pseudo_label, all_candidates) - min_revision_num = np.min(cost_list) - revision_num = min(max_revision_num, min_revision_num + require_more_revision) - idxs = np.where(cost_list <= revision_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - return candidates - - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - - class prolog_KB(KBBase): def __init__(self, pseudo_label_list, pl_file, max_err=0): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index d866a99..e853b0e 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from abl.utils.utils import ( +from ..utils.utils import ( confidence_dist, flatten, reform_idx, @@ -13,7 +13,7 @@ class ReasonerBase: """ Base class for reasoner. - Attributes + Parameters ---------- kb : The knowledge base to be used for reasoning. @@ -115,7 +115,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability distribution over all pseudo labels). y : Any - Ground truth. + Ground truth for the result (after passing through the logic part). max_revision_num : int Specifies the maximum number of revisions allowed. """ @@ -162,7 +162,7 @@ class ReasonerBase: pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth. + Ground truth for the result (after passing through the logic part). revision_idx : array-like Indices of where revisions should be made to the predicted pseudo label. """ @@ -181,8 +181,8 @@ class ReasonerBase: distribution over all pseudo labels). pred_pseudo_label : List[Any] Predicted pseudo label. - y : any - Ground truth. + y : Any + Ground truth for the result (after passing through the logic part). max_revision : int or float, optional The upper limit on the number of revisions. If float, denotes the fraction of the total length that can be revised. A value of -1 implies no restriction on the number @@ -456,7 +456,7 @@ if __name__ == "__main__": print() print("HWF_KB with GKB, max_err=0.1") - kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1) + kb = HWF_ground_KB(GKB_len_list=[1, 3, 5, 7], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner)