From 3bfd2c8d11714baa4653144e9625d68bf61c8f93 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Tue, 14 Nov 2023 22:29:13 +0800 Subject: [PATCH] [ENH] change Cache to decorator abl_cache --- abl/utils/__init__.py | 2 +- abl/utils/cache.py | 31 ++++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/abl/utils/__init__.py b/abl/utils/__init__.py index 526b50b..bbd0d81 100644 --- a/abl/utils/__init__.py +++ b/abl/utils/__init__.py @@ -1,3 +1,3 @@ -from .cache import Cache +from .cache import Cache, abl_cache from .logger import ABLLogger, print_log from .utils import * diff --git a/abl/utils/cache.py b/abl/utils/cache.py index f4b3b0c..dbf60d0 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -1,6 +1,5 @@ import pickle from os import PathLike -from pathlib import Path from typing import Callable, Generic, Hashable, TypeVar, Union from .logger import print_log @@ -14,8 +13,8 @@ class Cache(Generic[K, T]): def __init__( self, func: Callable[[K], T], - cache: bool, - cache_file: Union[None, str, PathLike], + cache: bool = True, + cache_file: Union[None, str, PathLike] = None, key_func: Callable[[K], Hashable] = lambda x: x, max_size: int = 4096, ): @@ -67,11 +66,12 @@ class Cache(Generic[K, T]): link = [last, self.root, cache_key, result] last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link - def get(self, item: K, *args) -> T: - return self.first(item, *args) + def get(self, obj, item: K, *args) -> T: + return self.first(obj, item, *args) - def get_from_dict(self, item: K, *args) -> T: + def get_from_dict(self, obj, item: K, *args) -> T: """Implements dict based cache.""" + # result = self.func(obj, item, *args) cache_key = (self.key_func(item), *args) link = self.cache_dict.get(cache_key) if link is not None: @@ -87,7 +87,7 @@ class Cache(Generic[K, T]): return result self.misses += 1 - result = self.func(item, *args) + result = self.func(obj, item, *args) if self.full: # Use the old root to store the new key and result. @@ -110,3 +110,20 @@ class Cache(Generic[K, T]): if isinstance(self.maxsize, int): self.full = len(self.cache_dict) >= self.maxsize return result + + +def abl_cache( + cache: bool = True, + cache_file: Union[None, str, PathLike] = None, + key_func: Callable[[K], Hashable] = lambda x: x, + max_size: int = 4096, +): + def decorator(func): + cache_instance = Cache(func, cache, cache_file, key_func, max_size) + + def wrapper(self, *args, **kwargs): + return cache_instance.get(self, *args, **kwargs) + + return wrapper + + return decorator