| @@ -113,7 +113,6 @@ class KBBase(ABC): | |||
| min_address_num = 0 | |||
| all_candidates_save = [] | |||
| cost_list_save = [] | |||
| for p_res, k in zip(pred_res, key): | |||
| if len(p_res) not in self.len_list: | |||
| return [] | |||
| @@ -136,7 +135,6 @@ class KBBase(ABC): | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| @@ -145,10 +143,8 @@ class KBBase(ABC): | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -167,15 +163,10 @@ class KBBase(ABC): | |||
| @lru_cache(maxsize=100) | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| # if self.abduce_cache: | |||
| # candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| # if candidates is not None: | |||
| # return candidates | |||
| pred_res = hashable_to_list(pred_res) | |||
| key = hashable_to_list(key) | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||
| @@ -183,23 +174,17 @@ class KBBase(ABC): | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [] | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| return candidates | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| # if self.abduce_cache: | |||
| # self._set_abduce_cache(pred_res, key, min_address_num, address_num, candidates, multiple_predictions) | |||
| return candidates | |||
| def _dict_len(self, dic): | |||