| @@ -17,7 +17,7 @@ import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list | |||
| from multiprocessing import Pool | |||
| @@ -83,7 +83,10 @@ class KBBase(ABC): | |||
| if self.GKB_flag: | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) | |||
| if not self.use_cache: | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| if self.max_err == 0: | |||
| @@ -148,17 +151,8 @@ class KBBase(ABC): | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.use_cache: | |||
| return self._abduce_by_search_cache(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) | |||
| @lru_cache(maxsize=None) | |||
| def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): | |||
| return self._abduce_by_search_no_cache(pred_res, key, max_address_num, require_more_address) | |||
| def _abduce_by_search_no_cache(self, pred_res, key, max_address_num, require_more_address): | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| @@ -180,7 +174,11 @@ class KBBase(ABC): | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates | |||
| @lru_cache(maxsize=None) | |||
| def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address): | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address) | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| @@ -194,8 +192,8 @@ class KBBase(ABC): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, use_cache=True): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, use_cache) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| @@ -273,9 +271,10 @@ class HWF_KB(KBBase): | |||
| pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], | |||
| len_list=[1, 3, 5, 7], | |||
| GKB_flag=False, | |||
| max_err=1e-3 | |||
| max_err=1e-3, | |||
| use_cache=True | |||
| ): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||