From c95faf043d3c2a2ad6fcccf85cf7c04573c3d306 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 15:56:42 +0800 Subject: [PATCH] [ENH] integrate choice of cache in to abl_cache --- abl/reasoning/kb.py | 201 +++++++++++++++++++++++--------------------- abl/utils/cache.py | 43 ++++------ 2 files changed, 121 insertions(+), 123 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index e1df939..3e5ca9e 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -12,6 +12,7 @@ import pyswip from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, restore_from_hashable from ..utils.cache import abl_cache + class KBBase(ABC): """ Base class for knowledge base. @@ -21,35 +22,36 @@ class KBBase(ABC): pseudo_label_list : list List of possible pseudo labels. max_err : float, optional - The upper tolerance limit when comparing the similarity between a candidate's logical - result. This is only applicable when the logical result is of a numerical type. - This is particularly relevant for regression problems where exact matches might not be - feasible. Defaults to 1e-10. + The upper tolerance limit when comparing the similarity between a candidate's logical + result. This is only applicable when the logical result is of a numerical type. + This is particularly relevant for regression problems where exact matches might not be + feasible. Defaults to 1e-10. use_cache : bool, optional - Whether to use a cache for previously abduced candidates to speed up subsequent + Whether to use a cache for previously abduced candidates to speed up subsequent operations. Defaults to True. - + Notes ----- - Users should inherit from this base class to build their own knowledge base. For the - user-build KB (an inherited subclass), it's only required for the user to provide the - `pseudo_label_list` and override the `logic_forward` function (specifying how to - perform logical reasoning). After that, other operations (e.g. how to perform abductive - reasoning) will be automatically set up. + Users should inherit from this base class to build their own knowledge base. For the + user-build KB (an inherited subclass), it's only required for the user to provide the + `pseudo_label_list` and override the `logic_forward` function (specifying how to + perform logical reasoning). After that, other operations (e.g. how to perform abductive + reasoning) will be automatically set up. """ + def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): if not isinstance(pseudo_label_list, list): raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err - self.use_cache = use_cache + self.use_cache = use_cache @abstractmethod def logic_forward(self, pseudo_label): """ - How to perform (deductive) logical reasoning, i.e. matching each pseudo label to + How to perform (deductive) logical reasoning, i.e. matching each pseudo label to their logical result. Users are required to provide this. - + Parameters ---------- pred_pseudo_label : List[Any] @@ -70,23 +72,22 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int, optional - Specifies additional number of revisions permitted beyond the minimum required. + Specifies additional number of revisions permitted beyond the minimum required. Defaults to 0. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo labels that are consistent with the + A list of candidates, i.e. revised pseudo labels that are consistent with the knowledge base. """ - if self.use_cache: - return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), - to_hashable(y), - max_revision_num, require_more_revision) - else: - return self._abduce_by_search(pred_pseudo_label, y, - max_revision_num, require_more_revision) - + # if self.use_cache: + # return self._abduce_by_search_cache(to_hashable(pred_pseudo_label), + # to_hashable(y), + # max_revision_num, require_more_revision) + # else: + return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + def _check_equal(self, logic_result, y): """ Check whether the logical result of a candidate is equal to the ground truth @@ -94,12 +95,12 @@ class KBBase(ABC): """ if logic_result == None: return False - + if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): return abs(logic_result - y) <= self.max_err else: return logic_result == y - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions. @@ -125,7 +126,7 @@ class KBBase(ABC): def _revision(self, revision_num, pred_pseudo_label, y): """ - For a specified number of pseudo label to revise, iterate through all possible + For a specified number of pseudo label to revise, iterate through all possible indices to find any candidates that are consistent with the knowledge base. """ new_candidates = [] @@ -136,12 +137,13 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + @abl_cache(max_size=4096) + def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ - Perform abductive reasoning by exhastive search. Specifically, begin with 0 and - continuously increase the number of pseudo labels to revise, until candidates + Perform abductive reasoning by exhastive search. Specifically, begin with 0 and + continuously increase the number of pseudo labels to revise, until candidates that are consistent with the knowledge base are found. - + Parameters ---------- pred_pseudo_label : List[Any] @@ -151,16 +153,16 @@ class KBBase(ABC): max_revision_num : int The upper limit on the number of revisions. require_more_revision : int - If larger than 0, then after having found any candidates consistent with the - knowledge base, continue to increase the number pseudo labels to revise to + If larger than 0, then after having found any candidates consistent with the + knowledge base, continue to increase the number pseudo labels to revise to get more possible consistent candidates. Returns ------- List[List[Any]] - A list of candidates, i.e. revised pseudo label that are consistent with the + A list of candidates, i.e. revised pseudo label that are consistent with the knowledge base. - """ + """ candidates = [] for revision_num in range(len(pred_pseudo_label) + 1): if revision_num == 0 and self._check_equal(self.logic_forward(pred_pseudo_label), y): @@ -173,20 +175,22 @@ class KBBase(ABC): if revision_num >= max_revision_num: return [] - for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1): + for revision_num in range( + min_revision_num + 1, min_revision_num + require_more_revision + 1 + ): if revision_num > max_revision_num: return candidates candidates.extend(self._revision(revision_num, pred_pseudo_label, y)) return candidates - - @abl_cache(max_size=4096) - def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): - """ - `_abduce_by_search` with cache. - """ - pred_pseudo_label = restore_from_hashable(pred_pseudo_label) - y = restore_from_hashable(y) - return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) + + # @abl_cache(max_size=4096) + # def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision): + # """ + # `_abduce_by_search` with cache. + # """ + # pred_pseudo_label = restore_from_hashable(pred_pseudo_label) + # y = restore_from_hashable(y) + # return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision) def __repr__(self): return ( @@ -195,13 +199,13 @@ class KBBase(ABC): f"max_err={self.max_err!r}, " f"use_cache={self.use_cache!r}." ) - - + + class GroundKB(KBBase): """ - Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon - class initialization, storing all potential candidates along with their respective - logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. + Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon + class initialization, storing all potential candidates along with their respective + logical result. Ground KB can accelerate abductive reasoning in `abduce_candidates`. Parameters ---------- @@ -211,15 +215,16 @@ class GroundKB(KBBase): List of possible lengths of pseudo label. max_err : float, optional Refer to class `KBBase`. - + Notes ----- - Users can also inherit from this class to build their own knowledge base. Similar - to `KBBase`, users are only required to provide the `pseudo_label_list` and override + Users can also inherit from this class to build their own knowledge base. Similar + to `KBBase`, users are only required to provide the `pseudo_label_list` and override the `logic_forward` function. Additionally, users should provide the `GKB_len_list`. - After that, other operations (e.g. auto-construction of GKB, and how to perform + After that, other operations (e.g. auto-construction of GKB, and how to perform abductive reasoning) will be automatically set up. """ + def __init__(self, pseudo_label_list, GKB_len_list, max_err=1e-10): super().__init__(pseudo_label_list, max_err) if not isinstance(GKB_len_list, list): @@ -229,7 +234,6 @@ class GroundKB(KBBase): X, Y = self._get_GKB() for x, y in zip(X, Y): self.GKB.setdefault(len(x), defaultdict(list))[y].append(x) - def _get_XY_list(self, args): pre_x, post_x_it = args[0], args[1] @@ -259,21 +263,21 @@ class GroundKB(KBBase): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if Y and isinstance(Y[0], (int, float)): + if Y and isinstance(Y[0], (int, float)): X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) return X, Y - + def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0): """ - Perform abductive reasoning by directly retrieving consistent candidates from - the prebuilt GKB. In this way, the time-consuming exhaustive search can be + Perform abductive reasoning by directly retrieving consistent candidates from + the prebuilt GKB. In this way, the time-consuming exhaustive search can be avoided. - This is an overridden function. For more information about the parameters and + This is an overridden function. For more information about the parameters and returns, refer to the function of the same name in class `KBBase`. """ if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list: return [] - + all_candidates = self._find_candidate_GKB(pred_pseudo_label, y) if len(all_candidates) == 0: return [] @@ -284,29 +288,30 @@ class GroundKB(KBBase): idxs = np.where(cost_list <= revision_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates - + def _find_candidate_GKB(self, pred_pseudo_label, y): """ - Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, - return all candidates whose logical results fall within the + Retrieve consistent candidates from the prebuilt GKB. For numerical logical results, + return all candidates whose logical results fall within the [y - max_err, y + max_err] range. """ if isinstance(y, (int, float)): potential_candidates = self.GKB[len(pred_pseudo_label)] key_list = list(potential_candidates.keys()) - + low_key = bisect.bisect_left(key_list, y - self.max_err) high_key = bisect.bisect_right(key_list, y + self.max_err) - all_candidates = [candidate - for key in key_list[low_key:high_key] - for candidate in potential_candidates[key]] + all_candidates = [ + candidate + for key in key_list[low_key:high_key] + for candidate in potential_candidates[key] + ] return all_candidates - + else: return self.GKB[len(pred_pseudo_label)][y] - - + def __repr__(self): return ( f"{self.__class__.__name__} is a KB with " @@ -321,78 +326,80 @@ class GroundKB(KBBase): class PrologKB(KBBase): """ Knowledge base provided by a Prolog (.pl) file. - + Parameters ---------- pseudo_label_list : list Refer to class `KBBase`. - pl_file : - Prolog file containing the KB. + pl_file : + Prolog file containing the KB. max_err : float, optional Refer to class `KBBase`. - + Notes ----- - Users can instantiate this class to build their own knowledge base. During the + Users can instantiate this class to build their own knowledge base. During the instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`. - To use the default logic forward and abductive reasoning methods in this class, in the - Prolog (.pl) file, there needs to be a rule which is strictly formatted as + To use the default logic forward and abductive reasoning methods in this class, in the + Prolog (.pl) file, there needs to be a rule which is strictly formatted as `logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`. - For specifics, refer to the `logic_forward` and `get_query_string` functions in this + For specifics, refer to the `logic_forward` and `get_query_string` functions in this class. Users are also welcome to override related functions for more flexible support. """ + def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list) self.pl_file = pl_file self.prolog = pyswip.Prolog() - + if not os.path.exists(self.pl_file): raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") self.prolog.consult(self.pl_file) def logic_forward(self, pseudo_labels): """ - Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the - returned `Res` as the logical results. To use this default function, there must be - a Prolog `log_forward` method in the pl file to perform logical. reasoning. + Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the + returned `Res` as the logical results. To use this default function, there must be + a Prolog `log_forward` method in the pl file to perform logical. reasoning. Otherwise, users would override this function. """ - result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] - if result == 'true': + result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]["Res"] + if result == "true": return True - elif result == 'false': + elif result == "false": return False return result - + def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx): import re + revision_pred_pseudo_label = pred_pseudo_label.copy() revision_pred_pseudo_label = flatten(revision_pred_pseudo_label) - + for idx in revision_idx: - revision_pred_pseudo_label[idx] = 'P' + str(idx) + revision_pred_pseudo_label[idx] = "P" + str(idx) revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label) - + regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label)) - + def get_query_string(self, pred_pseudo_label, y, revision_idx): """ - Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set - the returned `Revise_labels` together with the kept labels as the candidates. This is + Consult prolog with `logic_forward([kept_labels, Revise_labels], Res).`, and set + the returned `Revise_labels` together with the kept labels as the candidates. This is a default fuction for demo, users would override this function to adapt to their own - Prolog file. + Prolog file. """ query_string = "logic_forward(" query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx) key_is_none_flag = y is None or (type(y) == list and y[0] is None) query_string += ",%s)." % y if not key_is_none_flag else ")." return query_string - + def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """ Revise the predicted pseudo label at specified index positions by querying Prolog. - This is an overridden function. For more information about the parameters, refer to + This is an overridden function. For more information about the parameters, refer to the function of the same name in class `KBBase`. """ candidates = [] @@ -414,4 +421,4 @@ class PrologKB(KBBase): f"pseudo_label_list={self.pseudo_label_list!r}, " f"defined by " f"Prolog file {self.pl_file!r}." - ) \ No newline at end of file + ) diff --git a/abl/utils/cache.py b/abl/utils/cache.py index dbf60d0..a342d0a 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -3,6 +3,7 @@ from os import PathLike from typing import Callable, Generic, Hashable, TypeVar, Union from .logger import print_log +from .utils import to_hashable K = TypeVar("K") T = TypeVar("T") @@ -13,7 +14,6 @@ class Cache(Generic[K, T]): def __init__( self, func: Callable[[K], T], - cache: bool = True, cache_file: Union[None, str, PathLike] = None, key_func: Callable[[K], Hashable] = lambda x: x, max_size: int = 4096, @@ -27,23 +27,15 @@ class Cache(Generic[K, T]): """ self.func = func self.key_func = key_func - self.cache = cache - if cache is True or cache_file is not None: - print_log("Caching is activated", logger="current") - self._init_cache(cache_file, max_size) - self.first = self.get_from_dict - else: - self.first = self.func - def __getitem__(self, item: K, *args) -> T: - return self.first(item, *args) + self._init_cache(cache_file, max_size) + + def __getitem__(self, obj, *args) -> T: + return self.get_from_dict(obj, *args) - def invalidate(self): + def clear_cache(self): """Invalidate entire cache.""" self.cache_dict.clear() - if self.cache_file: - for p in self.cache_root.iterdir(): - p.unlink() def _init_cache(self, cache_file, max_size): self.cache = True @@ -66,13 +58,10 @@ class Cache(Generic[K, T]): link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - def get(self, obj, item: K, *args) -> T: - return self.first(obj, item, *args) - - def get_from_dict(self, obj, item: K, *args) -> T: + def get_from_dict(self, obj, *args) -> T: """Implements dict based cache.""" - # result = self.func(obj, item, *args) - cache_key = (self.key_func(item), *args) + pred_pseudo_label, y, *res_args = args + cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) link = self.cache_dict.get(cache_key) if link is not None: # Move the link to the front of the circular queue @@ -87,7 +76,7 @@ class Cache(Generic[K, T]): return result self.misses += 1 - result = self.func(obj, item, *args) + result = self.func(obj, *args) if self.full: # Use the old root to store the new key and result. @@ -113,16 +102,18 @@ class Cache(Generic[K, T]): def abl_cache( - cache: bool = True, cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = lambda x: x, + key_func: Callable[[K], Hashable] = to_hashable, max_size: int = 4096, ): def decorator(func): - cache_instance = Cache(func, cache, cache_file, key_func, max_size) + cache_instance = Cache(func, cache_file, key_func, max_size) - def wrapper(self, *args, **kwargs): - return cache_instance.get(self, *args, **kwargs) + def wrapper(obj, *args): + if obj.use_cache: + return cache_instance.get_from_dict(obj, *args) + else: + return func(obj, *args) return wrapper