From 233fc9738d2c92c18bf9ec7dbaa8bd9c951b5519 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Wed, 1 Nov 2023 20:52:49 +0800 Subject: [PATCH] [MNT] add docstring for class KBBase --- abl/reasoning/kb.py | 204 +++++++++++++++++++++++++++----------- abl/reasoning/reasoner.py | 17 ++-- 2 files changed, 156 insertions(+), 65 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index b700aa1..1cbd291 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -5,7 +5,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from abl.utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -13,11 +13,31 @@ from functools import lru_cache import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list, max_err=0, use_cache=True): - # TODO:添加一下类型检查,比如 - # if not isinstance(X, (np.ndarray, spmatrix)): - # raise TypeError("X should be numpy array or sparse matrix") + """ + Base class for knowledge base. + Attributes + ---------- + pseudo_label_list : list + List of possible pseudo labels. + max_err : float, optional + The upper tolerance limit when comparing the similarity between a candidate result + and the ground truth. Especially relevant for regression problems where exact matches + might not be feasible. Defaults to 0. + use_cache : bool, optional + Whether to use a cache for previously abduced candidates to speed up subsequent + operations. Defaults to True. + + Notes + ----- + Users creating their own KB should inherit from this base class. For the inherited + subclass, it's mandatory to provide `pseudo_label_list` and override the `logic_forward` + function. After that, other operations (e.g. how to perform abductive reasoning) + will be automatically set up. 
+ """ + def __init__(self, pseudo_label_list, max_err=0, use_cache=True): + if not isinstance(pseudo_label_list, list): + raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err self.use_cache = use_cache @@ -26,39 +46,105 @@ class KBBase(ABC): def logic_forward(self, pseudo_labels): pass - def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): + def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): + """ + Perform abductive reasoning to get a candidate consistent with the knowledge base. + + Parameters + ---------- + pred_pseudo_label : List[Any] + Predicted pseudo label. + y : any + Ground truth. + max_revision_num : int + The upper limit on the number of revisions. + require_more_revision : int, optional + Specifies additional number of revisions permitted beyond the minimum required. + Defaults to 0. + + Returns + ------- + List[List[Any]] + A list of candidates, i.e. revised pseudo label that are consistent with the + knowledge base. + """ if not self.use_cache: - return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) + return self._abduce_by_search(pred_pseudo_label, y, + max_revision_num, require_more_revision) else: - return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision) + return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), + to_hashable(y), + max_revision_num, require_more_revision) - def revise_by_idx(self, pred_res, y, revision_idx): + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): + """ + Revise the predicted pseudo label at specified index positions. + + Parameters + ---------- + pred_pseudo_label : List[Any] + Predicted pseudo label. + y : Any + Ground truth. + revision_idx : array-like + Indices of where revisions should be made to the predicted pseudo label. 
+ """ candidates = [] abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) for c in abduce_c: - candidate = pred_res.copy() + candidate = pred_pseudo_label.copy() for i, idx in enumerate(revision_idx): candidate[idx] = c[i] if check_equal(self.logic_forward(candidate), y, self.max_err): candidates.append(candidate) return candidates - def _revision(self, revision_num, pred_res, y): + def _revision(self, revision_num, pred_pseudo_label, y): + """ + For a specified number of pseudo labels to revise, iterate through all possible + indices to find any candidates that are consistent with the knowledge base. + """ new_candidates = [] - revision_idx_list = combinations(range(len(pred_res)), revision_num) + revision_idx_list = combinations(range(len(pred_pseudo_label)), revision_num) for revision_idx in revision_idx_list: - candidates = self.revise_by_idx(pred_res, y, revision_idx) + candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx) new_candidates.extend(candidates) return new_candidates - def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): + def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + """ + Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and + continuously increase the number of pseudo labels to revise, until candidates + that are consistent with the knowledge base are found. + + Parameters + ---------- + pred_pseudo_label : List[Any] + Predicted pseudo label. + y : any + Ground truth. + max_revision_num : int + The upper limit on the number of revisions. + require_more_revision : int + If larger than 0, then after having found any candidates consistent with the + knowledge base, continue to increase the number of pseudo labels to revise to + get more possible consistent candidates. + + Returns + ------- + List[List[Any]] + A list of candidates, i.e. revised pseudo label that are consistent with the + knowledge base. 
+ """ candidates = [] - for revision_num in range(len(pred_res) + 1): - if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err): - candidates.append(pred_res) + for revision_num in range(len(pred_pseudo_label) + 1): + if revision_num == 0 and check_equal(self.logic_forward(pred_pseudo_label), + y, + self.max_err): + candidates.append(pred_pseudo_label) elif revision_num > 0: - candidates.extend(self._revision(revision_num, pred_res, y)) + candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) if len(candidates) > 0: min_revision_num = revision_num break @@ -68,26 +154,17 @@ class KBBase(ABC): for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): if revision_num > max_revision_num: return candidates - candidates.extend(self._revision(revision_num, pred_res, y)) + candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) return candidates @lru_cache(maxsize=None) - def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision): - pred_res = hashable_to_list(pred_res) + def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + """ + `_abduce_by_search` with cache. 
+ """ + pred_pseudo_label = hashable_to_list(pred_pseudo_label) y = hashable_to_list(y) - return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision) - - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) + return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) class ground_KB(KBBase): def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): @@ -130,14 +207,14 @@ class ground_KB(KBBase): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y - def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): - return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision) + def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): + return self._abduce_by_GKB(pred_pseudo_label, y, max_revision_num, require_more_revision) - def _find_candidate_GKB(self, pred_res, y): + def _find_candidate_GKB(self, pred_pseudo_label, y): if self.max_err == 0: - return self.base[len(pred_res)][y] + return self.base[len(pred_pseudo_label)][y] else: - potential_candidates = self.base[len(pred_res)] + potential_candidates = self.base[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) key_idx = bisect.bisect_left(key_list, y) @@ -157,21 +234,34 @@ class ground_KB(KBBase): break return all_candidates - def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision): - if self.base == {} or len(pred_res) not in self.GKB_len_list: + def _abduce_by_GKB(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + if self.base == {} or len(pred_pseudo_label) not in self.GKB_len_list: return [] - all_candidates = self._find_candidate_GKB(pred_res, y) + all_candidates = 
self._find_candidate_GKB(pred_pseudo_label, y) if len(all_candidates) == 0: return [] - cost_list = hamming_dist(pred_res, all_candidates) + cost_list = hamming_dist(pred_pseudo_label, all_candidates) min_revision_num = np.min(cost_list) revision_num = min(max_revision_num, min_revision_num + require_more_revision) idxs = np.where(cost_list <= revision_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) + + def __len__(self): + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) + + class prolog_KB(KBBase): def __init__(self, pseudo_label_list, pl_file, max_err=0): @@ -187,36 +277,36 @@ class prolog_KB(KBBase): return False return result - def _revision_pred_res(self, pred_res, revision_idx): + def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): import re - revision_pred_res = pred_res.copy() - revision_pred_res = flatten(revision_pred_res) + revision_pred_pseudo_label = pred_pseudo_label.copy() + revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) for idx in revision_idx: - revision_pred_res[idx] = 'P' + str(idx) - revision_pred_res = reform_idx(revision_pred_res, pred_res) + revision_pred_pseudo_label[idx] = 'P' + str(idx) + revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) # TODO:不知道有没有更简洁的方法 regex = r"'P\d+'" - return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res)) + return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) - def get_query_string(self, pred_res, y, revision_idx): + def get_query_string(self, pred_pseudo_label, y, revision_idx): query_string = "logic_forward(" - query_string += self._revision_pred_res(pred_res, revision_idx) + query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) key_is_none_flag = y is 
None or (type(y) == list and y[0] is None) query_string += ",%s)." % y if not key_is_none_flag else ")." return query_string - def revise_by_idx(self, pred_res, y, revision_idx): + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): candidates = [] - query_string = self.get_query_string(pred_res, y, revision_idx) - save_pred_res = pred_res - pred_res = flatten(pred_res) + query_string = self.get_query_string(pred_pseudo_label, y, revision_idx) + save_pred_pseudo_label = pred_pseudo_label + pred_pseudo_label = flatten(pred_pseudo_label) abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] for c in abduce_c: - candidate = pred_res.copy() + candidate = pred_pseudo_label.copy() for i, idx in enumerate(revision_idx): candidate[idx] = c[i] - candidate = reform_idx(candidate, save_pred_res) + candidate = reform_idx(candidate, save_pred_pseudo_label) candidates.append(candidate) return candidates diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index e221792..d866a99 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,6 +1,6 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import ( +from abl.utils.utils import ( confidence_dist, flatten, reform_idx, @@ -23,7 +23,7 @@ class ReasonerBase: | `"confidence"`. Any other options will raise a `NotImplementedError`. For detailed explanations of these options, refer to `_get_cost_list`. mapping : dict, optional - A mapping from label to index. If not provided, a default order-based mapping is + A mapping from index to label. If not provided, a default order-based mapping is created. use_zoopt : bool, optional Whether to use the Zoopt library during abductive reasoning. Default to False. 
@@ -44,6 +44,7 @@ class ReasonerBase: if not isinstance(mapping, dict): raise TypeError("mapping should be dict") self.mapping = mapping + self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): """ @@ -57,7 +58,7 @@ class ReasonerBase: Predicted pseudo label to be used for selecting a candidate. pred_prob : List[List[Any]] Predicted probabilities of the prediction (Each sublist contains the probability - values of all pseudo labels). + distribution over all pseudo labels). candidates : List[List[Any]] Multiple candidate abduction results. """ @@ -85,7 +86,7 @@ class ReasonerBase: Predicted pseudo label. pred_prob : List[List[Any]] Predicted probabilities of the prediction (Each sublist contains the probability - values of all pseudo labels). Used when distance function is "confidence". + distribution over all pseudo labels). Used when distance function is "confidence". candidates : List[List[Any]] Multiple candidate abduction results. """ @@ -93,7 +94,7 @@ class ReasonerBase: return hamming_dist(pred_pseudo_label, candidates) elif self.dist_func == "confidence": - candidates = [[self.mapping[x] for x in c] for c in candidates] + candidates = [[self.remapping[x] for x in c] for c in candidates] return confidence_dist(pred_prob, candidates) @@ -112,7 +113,7 @@ class ReasonerBase: Predicted pseudo label. pred_prob : List[List[Any]] Predicted probabilities of the prediction (Each sublist contains the probability - values of all pseudo labels). + distribution over all pseudo labels). y : Any Ground truth. max_revision_num : int @@ -177,7 +178,7 @@ class ReasonerBase: ---------- pred_prob : List[List[Any]] Predicted probabilities of the prediction (Each sublist contains the probability - values of all pseudo labels). + distribution over all pseudo labels). pred_pseudo_label : List[Any] Predicted pseudo label. 
y : any @@ -193,7 +194,7 @@ class ReasonerBase: Returns ------- List[Any] - The revised pseudo label through abductive reasoning, which is consistent with the + A revised pseudo label through abductive reasoning, which is consistent with the knowledge base. """ symbol_num = len(flatten(pred_pseudo_label))