diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 6cecba4..333901e 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -113,7 +113,6 @@ class KBBase(ABC): min_address_num = 0 all_candidates_save = [] cost_list_save = [] - for p_res, k in zip(pred_res, key): if len(p_res) not in self.len_list: return [] @@ -136,7 +135,6 @@ class KBBase(ABC): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) - if multiple_predictions: save_pred_res = pred_res pred_res = flatten(pred_res) @@ -145,10 +143,8 @@ class KBBase(ABC): candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - if multiple_predictions: candidate = reform_idx(candidate, save_pred_res) - if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): candidates.append(candidate) return candidates @@ -167,15 +163,10 @@ class KBBase(ABC): @lru_cache(maxsize=100) def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - # if self.abduce_cache: - # candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions) - # if candidates is not None: - # return candidates pred_res = hashable_to_list(pred_res) key = hashable_to_list(key) candidates = [] - for address_num in range(len(flatten(pred_res)) + 1): if address_num == 0: if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): @@ -183,23 +174,17 @@ class KBBase(ABC): else: new_candidates = self._address(address_num, pred_res, key, multiple_predictions) candidates += new_candidates - if len(candidates) > 0: min_address_num = address_num break - if address_num >= max_address_num: return [] for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): if address_num > max_address_num: - return candidates, min_address_num, address_num - 1 + return candidates new_candidates = self._address(address_num, pred_res, key, multiple_predictions) candidates += new_candidates - - # if self.abduce_cache: - # self._set_abduce_cache(pred_res, key, min_address_num, address_num, candidates, multiple_predictions) - return candidates def _dict_len(self, dic):