From 7e5292eccbe3f03ae1660bb0d631a44f7cb8b5c4 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 15 Nov 2023 16:55:21 +0800 Subject: [PATCH] [MNT] use parameters of kb to initialize abl_cache --- abl/reasoning/kb.py | 16 ++++++++++++++-- abl/utils/cache.py | 45 +++++++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 3e5ca9e..1bca1b3 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -39,12 +39,24 @@ class KBBase(ABC): reasoning) will be automatically set up. """ - def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True): + def __init__( + self, + pseudo_label_list, + max_err=1e-10, + use_cache=True, + cache_file=None, + key_func=to_hashable, + max_cache_size=4096, + ): if not isinstance(pseudo_label_list, list): raise TypeError("pseudo_label_list should be list") self.pseudo_label_list = pseudo_label_list self.max_err = max_err + self.use_cache = use_cache + self.cache_file = cache_file + self.key_func = key_func + self.max_cache_size = max_cache_size @abstractmethod def logic_forward(self, pseudo_label): @@ -137,7 +149,7 @@ class KBBase(ABC): new_candidates.extend(candidates) return new_candidates - @abl_cache(max_size=4096) + @abl_cache() def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision): """ Perform abductive reasoning by exhastive search. Specifically, begin with 0 and diff --git a/abl/utils/cache.py b/abl/utils/cache.py index a342d0a..418c93f 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -11,13 +11,7 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields class Cache(Generic[K, T]): - def __init__( - self, - func: Callable[[K], T], - cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = lambda x: x, - max_size: int = 4096, - ): + def __init__(self, func: Callable[[K], T]): """Create cache :param func: Function this cache evaluates @@ -26,9 +20,7 @@ class Cache(Generic[K, T]): :param key_func: Convert the key into a hashable object if needed """ self.func = func - self.key_func = key_func - - self._init_cache(cache_file, max_size) + self.has_init = False def __getitem__(self, obj, *args) -> T: return self.get_from_dict(obj, *args) @@ -37,27 +29,35 @@ class Cache(Generic[K, T]): """Invalidate entire cache.""" self.cache_dict.clear() - def _init_cache(self, cache_file, max_size): + def _init_cache(self, obj): + if self.has_init: + return + self.cache = True self.cache_dict = dict() + self.key_func = obj.key_func + self.cache_file = obj.cache_file + self.max_size = obj.max_cache_size - self.hits, self.misses, self.maxsize = 0, 0, max_size + self.hits, self.misses = 0, 0 self.full = False self.root = [] # root of the circular doubly linked list self.root[:] = [self.root, self.root, None, None] - if cache_file is not None: - with open(cache_file, "rb") as f: + if self.cache_file is not None: + with open(self.cache_file, "rb") as f: cache_dict_from_file = pickle.load(f) - self.maxsize += len(cache_dict_from_file) + self.max_size += len(cache_dict_from_file) print_log( - f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current" + f"Max size of the cache has been enlarged to {self.max_size}.", logger="current" ) for cache_key, result in cache_dict_from_file.items(): last = self.root[PREV] link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link + self.has_init = True + def get_from_dict(self, obj, *args) -> T: """Implements dict based cache.""" pred_pseudo_label, y, *res_args = args @@ -96,21 +96,18 @@ class Cache(Generic[K, T]): last = self.root[PREV] link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - if isinstance(self.maxsize, int): - self.full = len(self.cache_dict) >= self.maxsize + if isinstance(self.max_size, int): + self.full = len(self.cache_dict) >= self.max_size return result -def abl_cache( - cache_file: Union[None, str, PathLike] = None, - key_func: Callable[[K], Hashable] = to_hashable, - max_size: int = 4096, -): +def abl_cache(): def decorator(func): - cache_instance = Cache(func, cache_file, key_func, max_size) + cache_instance = Cache(func) def wrapper(obj, *args): if obj.use_cache: + cache_instance._init_cache(obj) return cache_instance.get_from_dict(obj, *args) else: return func(obj, *args)