diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index afb508d..675b200 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -17,7 +17,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -25,7 +25,7 @@ from functools import lru_cache import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, cache_size=128): + def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0, use_cache=True): # TODO:添加一下类型检查,比如 # if not isinstance(X, (np.ndarray, spmatrix)): # raise TypeError("X should be numpy array or sparse matrix") @@ -34,7 +34,7 @@ class KBBase(ABC): self.len_list = len_list self.GKB_flag = GKB_flag self.max_err = max_err - self.cache_size = cache_size + self.use_cache = use_cache if GKB_flag: self.base = {} @@ -147,32 +147,39 @@ class KBBase(ABC): return new_candidates def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): - @lru_cache(maxsize=self.cache_size) - def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address): - pred_res = hashable_to_list(pred_res) - key = hashable_to_list(key) - - candidates = [] - for address_num in range(len(pred_res) + 1): - if address_num == 0: - if check_equal(self.logic_forward(pred_res), key, self.max_err): - candidates.append(pred_res) - else: - new_candidates = self._address(address_num, pred_res, key) - candidates += new_candidates - if len(candidates) > 0: - min_address_num = address_num - break - if address_num >= max_address_num: - return [] - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if address_num > max_address_num: - return candidates + if self.use_cache: + return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address) + else: + return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) + + @lru_cache(maxsize=None) + def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): + return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) + + def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address): + pred_res = hashable_to_list(pred_res) + key = hashable_to_list(key) + + candidates = [] + for address_num in range(len(pred_res) + 1): + if address_num == 0: + if check_equal(self.logic_forward(pred_res), key, self.max_err): + candidates.append(pred_res) + else: new_candidates = self._address(address_num, pred_res, key) candidates += new_candidates - return candidates - return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address) + if len(candidates) > 0: + min_address_num = address_num + break + if address_num >= max_address_num: + return [] + + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if address_num > max_address_num: + return candidates + new_candidates = self._address(address_num, pred_res, key) + candidates += new_candidates + return candidates def _dict_len(self, dic): if not self.GKB_flag: