diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 675b200..c470af3 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -17,7 +17,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list from multiprocessing import Pool @@ -83,7 +83,10 @@ class KBBase(ABC): if self.GKB_flag: return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) else: - return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) + if not self.use_cache: + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address) + else: + return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) def _find_candidate_GKB(self, pred_res, key): if self.max_err == 0: @@ -148,17 +151,8 @@ class KBBase(ABC): def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): if self.use_cache: - return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address) - else: - return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) - - @lru_cache(maxsize=None) - def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): - return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) - - def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address): - pred_res = hashable_to_list(pred_res) - key = hashable_to_list(key) + pred_res = hashable_to_list(pred_res) + key = hashable_to_list(key) candidates = [] for address_num in range(len(pred_res) + 1): @@ -180,7 +174,11 @@ class KBBase(ABC): new_candidates = self._address(address_num, pred_res, key) candidates += new_candidates return candidates - + + @lru_cache(maxsize=None) + def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address) + def _dict_len(self, dic): if not self.GKB_flag: return 0 @@ -194,8 +192,8 @@ class KBBase(ABC): return sum(self._dict_len(v) for v in self.base.values()) class add_KB(KBBase): - def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): - super().__init__(pseudo_label_list, len_list, GKB_flag) + def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, use_cache=True): + super().__init__(pseudo_label_list, len_list, GKB_flag, use_cache) def logic_forward(self, nums): return sum(nums) @@ -273,9 +271,10 @@ class HWF_KB(KBBase): pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7], GKB_flag=False, - max_err=1e-3 + max_err=1e-3, + use_cache=True ): - super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) def _valid_candidate(self, formula): if len(formula) % 2 == 0: