From 650c172d61e4b9ef6c468ae07e2826086c6f3dec Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Thu, 2 Mar 2023 09:24:41 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index b78ab64..b9beb3f 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -21,7 +21,7 @@ sys.path.append("..") from collections import defaultdict from itertools import product, combinations -from utils.utils import flatten, reform_idx, hamming_dist, check_is_equal +from utils.utils import flatten, reform_idx, hamming_dist, check_equal from multiprocessing import Pool @@ -56,12 +56,12 @@ class KBBase(ABC): new_candidates += candidates return new_candidates - def _abduce_by_abduction(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): + def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): candidates = [] for address_num in range(len(flatten(pred_res)) + 1): if address_num == 0: - if check_is_equal(pred_res, key): + if check_equal(self.logic_forward(pred_res), key): candidates.append(pred_res) else: new_candidates = self._address(address_num, pred_res, key, multiple_predictions) @@ -114,19 +114,8 @@ class ClsKB(KBBase): XY_list.append((x, y)) return XY_list - # Parallel get GKB + # Parallel _get_GKB def _get_GKB(self): - # all_X = [] - # for length in len_list: - # all_X += list(product(self.pseudo_label_list, repeat = length)) - - # X, Y = [], [] - # for x in all_X: - # y = self.logic_forward(x) - # if y != np.inf: - # X.append(x) - # Y.append(y) - X, Y = [], [] for length in self.len_list: arg_list = [] @@ -148,11 +137,11 @@ class ClsKB(KBBase): def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: - return self._abduce_from_GKB(pred_res, key, max_address_num, require_more_address) + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) else: - return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - def _abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): if self.base == {} or len(pred_res) not in self.len_list: return [] @@ -260,7 +249,7 @@ class prolog_KB(KBBase): pass def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - return self._abduce_by_abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = []