| @@ -85,9 +85,29 @@ class KBBase(ABC): | |||
| else: | |||
| return self._abduce_by_search(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address) | |||
| @abstractmethod | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| pass | |||
| if self.max_err == 0: | |||
| return self.base[len(pred_res)][key] | |||
| else: | |||
| potential_candidates = self.base[len(pred_res)] | |||
| key_list = list(potential_candidates.keys()) | |||
| key_idx = bisect.bisect_left(key_list, key) | |||
| all_candidates = [] | |||
| for idx in range(key_idx - 1, 0, -1): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| for idx in range(key_idx, len(key_list)): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| return all_candidates | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {}: | |||
| @@ -126,33 +146,34 @@ class KBBase(ABC): | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| # TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache) | |||
| @lru_cache(maxsize=None) | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address): | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| @lru_cache(maxsize=self.cache_size) | |||
| def _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address): | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [] | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [] | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates | |||
| new_candidates = self._address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates | |||
| return candidates | |||
| return _cached_abduce_by_search(pred_res, key, max_address_num, require_more_address) | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| @@ -165,16 +186,7 @@ class KBBase(ABC): | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| return self.base[len(pred_res)][key] | |||
| class add_KB(ClsKB): | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| @@ -215,9 +227,6 @@ class prolog_KB(KBBase): | |||
| key_is_none_flag = key is None or (type(key) == list and key[0] is None) | |||
| query_string += ",%s)." % key if not key_is_none_flag else ")." | |||
| return query_string | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| pass | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| candidates = [] | |||
| @@ -251,33 +260,7 @@ class HED_prolog_KB(prolog_KB): | |||
| return rules | |||
| class RegKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag, max_err): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| potential_candidates = self.base[len(pred_res)] | |||
| key_list = list(potential_candidates.keys()) | |||
| key_idx = bisect.bisect_left(key_list, key) | |||
| all_candidates = [] | |||
| for idx in range(key_idx - 1, 0, -1): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| for idx in range(key_idx, len(key_list)): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| return all_candidates | |||
| class HWF_KB(RegKB): | |||
| class HWF_KB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], | |||