From 01f00d225e52a57e5af4db914e03eaa11ff071df Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Mon, 6 Nov 2023 09:34:02 +0800 Subject: [PATCH] [MNT] add docstring for class prolog_KB --- abl/reasoning/kb.py | 93 ++++++++++++++++++++++++--------------- abl/reasoning/reasoner.py | 14 +++--- 2 files changed, 65 insertions(+), 42 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index eee456c..cc34440 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -21,9 +21,9 @@ class KBBase(ABC): pseudo_label_list : list List of possible pseudo labels. max_err : float, optional - The upper tolerance limit when comparing the similarity between a candidate result - and the ground truth. Especially relevant for regression problems where exact matches - might not be feasible. Default to 0. + The upper tolerance limit when comparing the similarity between a candidate's logical + result and the ground truth. Especially relevant for regression problems where exact + matches might not be feasible. Default to 0. use_cache : bool, optional Whether to use a cache for previously abduced candidates to speed up subsequent operations. Defaults to True. @@ -46,7 +46,8 @@ class KBBase(ABC): @abstractmethod def logic_forward(self, pseudo_labels): """ - How to perform logical reasoning. Users are required to provide this. + How to perform (deductive) logical reasoning, i.e. matching each pseudo label to + their logical result. Users are required to provide this. """ pass @@ -59,7 +60,7 @@ class KBBase(ABC): pred_pseudo_label : List[Any] Predicted pseudo label. y : any - Ground truth for the result (after passing through the logic part). + Ground truth for the logical result. max_revision_num : int The upper limit on the number of revisions. require_more_revision : int, optional @@ -89,7 +90,7 @@ class KBBase(ABC): pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth for the result (after passing through the logic part). 
+ Ground truth for the logical result. revision_idx : array-like Indices of where revisions should be made to the predicted pseudo label. """ @@ -127,7 +128,7 @@ class KBBase(ABC): pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth for the result (after passing through the logic part). + Ground truth for the logical result. max_revision_num : int The upper limit on the number of revisions. require_more_revision : int @@ -173,11 +174,9 @@ class KBBase(ABC): class ground_KB(KBBase): """ - Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt - upon class initialization, stroing all potential candidates along with - their respective results after passing through the logic part. Ground KB can - enhance the speed of abductive reasoning. For more on this, refer to the - `abduce_candidates` method in this class. + Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon + class initialization, storing all potential candidates along with their respective + logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. Parameters ---------- @@ -190,11 +189,11 @@ class ground_KB(KBBase): Notes ----- - Users can also inherit from this class to build their own knowledge base. - Similar to `KBBase`, users are only required to provide the `pseudo_label_list` - and override the `logic_forward` function. Additionally, users should provide - the `GKB_len_list`. After that, other operations (e.g. auto-construction of - GKB, and how to perform abductive reasoning) will be automatically set up. + Users can also inherit from this class to build their own knowledge base. Similar + to `KBBase`, users are only required to provide the `pseudo_label_list` and override + the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. + After that, other operations (e.g. auto-construction of GKB, and how to perform + abductive reasoning) will be automatically set up. 
""" def __init__(self, pseudo_label_list, GKB_len_list, max_err=0): super().__init__(pseudo_label_list, max_err) @@ -272,32 +271,46 @@ class ground_KB(KBBase): else: potential_candidates = self.GKB[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) - key_idx = bisect.bisect_left(key_list, y) - all_candidates = [] - for idx in range(key_idx - 1, 0, -1): - k = key_list[idx] - if abs(k - y) <= self.max_err: - all_candidates.extend(potential_candidates[k]) - else: - break - - for idx in range(key_idx, len(key_list)): - k = key_list[idx] - if abs(k - y) <= self.max_err: - all_candidates.extend(potential_candidates[k]) - else: - break + low_key = bisect.bisect_left(key_list, y - self.max_err) + high_key = bisect.bisect_right(key_list, y + self.max_err) + + all_candidates = [candidate + for key in key_list[low_key:high_key] + for candidate in potential_candidates[key]] return all_candidates class prolog_KB(KBBase): - def __init__(self, pseudo_label_list, pl_file, max_err=0): - super().__init__(pseudo_label_list, max_err) + """ + Knowledge base given by a prolog (pl) file. + + Parameters + ---------- + pseudo_label_list : list + Refer to class `KBBase`. + pl_file : + Prolog file containing the KB. + max_err : float, optional + Refer to class `KBBase`. + + Notes + ----- + Users can also inherit from this class to build their own knowledge base. When using + this class, users are only required to provide the `pl_file`. + """ + def __init__(self, pseudo_label_list, pl_file): + super().__init__(pseudo_label_list) self.prolog = pyswip.Prolog() self.prolog.consult(pl_file) def logic_forward(self, pseudo_labels): + """ + Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the + returned `Res` as the logical results. To use this default function, there must be + a Prolog `logic_forward` method in the pl file to perform logical reasoning. Otherwise, + users should override this function. 
+ """ result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] if result == 'true': return True @@ -314,11 +327,16 @@ class prolog_KB(KBBase): revision_pred_pseudo_label[idx] = 'P' + str(idx) revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) - # TODO:不知道有没有更简洁的方法 regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) def get_query_string(self, pred_pseudo_label, y, revision_idx): + """ + Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set + the returned `Revise_labels` together with the kept labels as the candidates. This is + a default function for demo; users should override this function to adapt to their own + Prolog file. + """ query_string = "logic_forward(" query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) key_is_none_flag = y is None or (type(y) == list and y[0] is None) @@ -326,6 +344,11 @@ class prolog_KB(KBBase): return query_string def revise_at_idx(self, pred_pseudo_label, y, revision_idx): + """ + Revise the predicted pseudo label at specified index positions by querying Prolog. + This is an overridden function. For more information about the parameters, refer to + the function of the same name in class `KBBase`. + """ candidates = [] query_string = self.get_query_string(pred_pseudo_label, y, revision_idx) save_pred_pseudo_label = pred_pseudo_label diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index e853b0e..65879c7 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import ( +from abl.utils.utils import ( confidence_dist, flatten, reform_idx, @@ -60,7 +60,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability distribution over all pseudo labels). 
candidates : List[List[Any]] - Multiple candidate abduction results. + Multiple consistent candidates. """ if len(candidates) == 0: return [] @@ -88,7 +88,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability distribution over all pseudo labels). Used when distance function is "confidence". candidates : List[List[Any]] - Multiple candidate abduction results. + Multiple consistent candidates. """ if self.dist_func == "hamming": return hamming_dist(pred_pseudo_label, candidates) @@ -115,7 +115,7 @@ class ReasonerBase: Predicted probabilities of the prediction (Each sublist contains the probability distribution over all pseudo labels). y : Any - Ground truth for the result (after passing through the logic part). + Ground truth for the logical result. max_revision_num : int Specifies the maximum number of revisions allowed. """ @@ -162,7 +162,7 @@ class ReasonerBase: pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth for the result (after passing through the logic part). + Ground truth for the logical result. revision_idx : array-like Indices of where revisions should be made to the predicted pseudo label. """ @@ -182,7 +182,7 @@ class ReasonerBase: pred_pseudo_label : List[Any] Predicted pseudo label. y : Any - Ground truth for the result (after passing through the logic part). + Ground truth for the logical result. max_revision : int or float, optional The upper limit on the number of revisions. If float, denotes the fraction of the total length that can be revised. A value of -1 implies no restriction on the number @@ -456,7 +456,7 @@ if __name__ == "__main__": print() print("HWF_KB with GKB, max_err=0.1") - kb = HWF_ground_KB(GKB_len_list=[1, 3, 5, 7], max_err=0.1) + kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") test_hwf(reasoner)