Browse Source

[MNT] use parameters of kb to initialize abl_cache

pull/4/head
Gao Enhao 2 years ago
parent
commit
7e5292eccb
2 changed files with 35 additions and 26 deletions
  1. +14
    -2
      abl/reasoning/kb.py
  2. +21
    -24
      abl/utils/cache.py

+ 14
- 2
abl/reasoning/kb.py View File

@@ -39,12 +39,24 @@ class KBBase(ABC):
reasoning) will be automatically set up.
"""

def __init__(self, pseudo_label_list, max_err=1e-10, use_cache=True):
def __init__(
self,
pseudo_label_list,
max_err=1e-10,
use_cache=True,
cache_file=None,
key_func=to_hashable,
max_cache_size=4096,
):
if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list
self.max_err = max_err

self.use_cache = use_cache
self.cache_file = cache_file
self.key_func = key_func
self.max_cache_size = max_cache_size

@abstractmethod
def logic_forward(self, pseudo_label):
@@ -137,7 +149,7 @@ class KBBase(ABC):
new_candidates.extend(candidates)
return new_candidates

@abl_cache(max_size=4096)
@abl_cache()
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and


+ 21
- 24
abl/utils/cache.py View File

@@ -11,13 +11,7 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields


class Cache(Generic[K, T]):
def __init__(
self,
func: Callable[[K], T],
cache_file: Union[None, str, PathLike] = None,
key_func: Callable[[K], Hashable] = lambda x: x,
max_size: int = 4096,
):
def __init__(self, func: Callable[[K], T]):
"""Create cache

:param func: Function this cache evaluates
@@ -26,9 +20,7 @@ class Cache(Generic[K, T]):
:param key_func: Convert the key into a hashable object if needed
"""
self.func = func
self.key_func = key_func

self._init_cache(cache_file, max_size)
self.has_init = False

def __getitem__(self, obj, *args) -> T:
return self.get_from_dict(obj, *args)
@@ -37,27 +29,35 @@ class Cache(Generic[K, T]):
"""Invalidate entire cache."""
self.cache_dict.clear()

def _init_cache(self, cache_file, max_size):
def _init_cache(self, obj):
if self.has_init:
return

self.cache = True
self.cache_dict = dict()
self.key_func = obj.key_func
self.cache_file = obj.cache_file
self.max_size = obj.max_cache_size

self.hits, self.misses, self.maxsize = 0, 0, max_size
self.hits, self.misses = 0, 0
self.full = False
self.root = [] # root of the circular doubly linked list
self.root[:] = [self.root, self.root, None, None]

if cache_file is not None:
with open(cache_file, "rb") as f:
if self.cache_file is not None:
with open(self.cache_file, "rb") as f:
cache_dict_from_file = pickle.load(f)
self.maxsize += len(cache_dict_from_file)
self.max_size += len(cache_dict_from_file)
print_log(
f"Max size of the cache has been enlarged to {self.maxsize}.", logger="current"
f"Max size of the cache has been enlarged to {self.max_size}.", logger="current"
)
for cache_key, result in cache_dict_from_file.items():
last = self.root[PREV]
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link

self.has_init = True

def get_from_dict(self, obj, *args) -> T:
"""Implements dict based cache."""
pred_pseudo_label, y, *res_args = args
@@ -96,21 +96,18 @@ class Cache(Generic[K, T]):
last = self.root[PREV]
link = [last, self.root, cache_key, result]
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
if isinstance(self.maxsize, int):
self.full = len(self.cache_dict) >= self.maxsize
if isinstance(self.max_size, int):
self.full = len(self.cache_dict) >= self.max_size
return result


def abl_cache(
cache_file: Union[None, str, PathLike] = None,
key_func: Callable[[K], Hashable] = to_hashable,
max_size: int = 4096,
):
def abl_cache():
def decorator(func):
cache_instance = Cache(func, cache_file, key_func, max_size)
cache_instance = Cache(func)

def wrapper(obj, *args):
if obj.use_cache:
cache_instance._init_cache(obj)
return cache_instance.get_from_dict(obj, *args)
else:
return func(obj, *args)


Loading…
Cancel
Save