From 3a7971dd8745b60c3d2785c25a55716377b19dd2 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Tue, 28 Mar 2023 10:56:27 +0800 Subject: [PATCH] Remove ClsKB and RegKB, add cache_size --- abl/abducer/kb.py | 115 ++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 66 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 1421a97..afb508d 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -85,9 +85,29 @@ class KBBase(ABC): else: return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) - @abstractmethod def _find_candidate_GKB(self, pred_res, key): - pass + if self.max_err == 0: + return self.base[len(pred_res)][key] + else: + potential_candidates = self.base[len(pred_res)] + key_list = list(potential_candidates.keys()) + key_idx = bisect.bisect_left(key_list, key) + + all_candidates = [] + for idx in range(key_idx - 1, 0, -1): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + + for idx in range(key_idx, len(key_list)): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + return all_candidates def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): if self.base == {}: @@ -126,33 +146,34 @@ class KBBase(ABC): new_candidates += candidates return new_candidates - # TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache) - @lru_cache(maxsize=None) def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): - pred_res = hashable_to_list(pred_res) - key = hashable_to_list(key) - - candidates = [] - for address_num in range(len(pred_res) + 1): - if address_num == 0: - if check_equal(self.logic_forward(pred_res), key, self.max_err): - candidates.append(pred_res) - else: + @lru_cache(maxsize=self.cache_size) + def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address): + pred_res = hashable_to_list(pred_res) + key = hashable_to_list(key) + + candidates = [] + for address_num in range(len(pred_res) + 1): + if address_num == 0: + if check_equal(self.logic_forward(pred_res), key, self.max_err): + candidates.append(pred_res) + else: + new_candidates = self._address(address_num, pred_res, key) + candidates += new_candidates + if len(candidates) > 0: + min_address_num = address_num + break + if address_num >= max_address_num: + return [] + + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if address_num > max_address_num: + return candidates new_candidates = self._address(address_num, pred_res, key) candidates += new_candidates - if len(candidates) > 0: - min_address_num = address_num - break - if address_num >= max_address_num: - return [] - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if address_num > max_address_num: - return candidates - new_candidates = self._address(address_num, pred_res, key) - candidates += new_candidates - return candidates - + return candidates + return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address) + def _dict_len(self, dic): if not self.GKB_flag: return 0 @@ -165,16 +186,7 @@ class KBBase(ABC): else: return sum(self._dict_len(v) for v in self.base.values()) - -class ClsKB(KBBase): - def __init__(self, pseudo_label_list, len_list, GKB_flag): - super().__init__(pseudo_label_list, len_list, GKB_flag) - - def _find_candidate_GKB(self, pred_res, key): - return self.base[len(pred_res)][key] - - -class add_KB(ClsKB): +class add_KB(KBBase): def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): super().__init__(pseudo_label_list, len_list, GKB_flag) @@ -215,9 +227,6 @@ class prolog_KB(KBBase): key_is_none_flag = key is None or (type(key) == list and key[0] is None) query_string += ",%s)." % key if not key_is_none_flag else ")." return query_string - - def _find_candidate_GKB(self, pred_res, key): - pass def address_by_idx(self, pred_res, key, address_idx): candidates = [] @@ -251,33 +260,7 @@ class HED_prolog_KB(prolog_KB): return rules -class RegKB(KBBase): - def __init__(self, pseudo_label_list, len_list, GKB_flag, max_err): - super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) - - def _find_candidate_GKB(self, pred_res, key): - potential_candidates = self.base[len(pred_res)] - key_list = list(potential_candidates.keys()) - key_idx = bisect.bisect_left(key_list, key) - - all_candidates = [] - for idx in range(key_idx - 1, 0, -1): - k = key_list[idx] - if abs(k - key) <= self.max_err: - all_candidates += potential_candidates[k] - else: - break - - for idx in range(key_idx, len(key_list)): - k = key_list[idx] - if abs(k - key) <= self.max_err: - all_candidates += potential_candidates[k] - else: - break - return all_candidates - - -class HWF_KB(RegKB): +class HWF_KB(KBBase): def __init__( self, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],