From cdb8bc2067b8c3cce5f24a587d341bb002e6e56f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Thu, 24 Nov 2022 21:42:12 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 170 +++++++++++++++++++------------------------------- 1 file changed, 63 insertions(+), 107 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 2600f8b..63c4959 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -32,23 +32,32 @@ class KBBase(ABC): def abduce_candidates(self): pass - def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address): - if len(all_candidates) == 0: - candidates = [] - min_address_num = 0 - address_num = 0 - else: - cost_list = self.hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] + + def abduction(self, pred_res, key, max_address_num, require_more_address): + candidates = [] + for address_num in range(len(pred_res) + 1): + if address_num == 0: + if abs(self.logic_forward(pred_res) - key) <= 1e-3: + candidates.append(pred_res) + else: + new_candidates = self.address(address_num, pred_res, key) + candidates += new_candidates + + if len(candidates) > 0: + min_address_num = address_num + break + + if address_num >= max_address_num: + return [], 0, 0 + + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if address_num > max_address_num: + return candidates, min_address_num, address_num - 1 + new_candidates = self.address(address_num, pred_res, key) + candidates += new_candidates + return candidates, min_address_num, address_num - def hamming_dist(self, A, B): - B = np.array(B) - A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) - return np.sum(A != B, axis = 1) def __len__(self): pass @@ -90,66 +99,38 @@ class ClsKB(KBBase): pass def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): - if max_address_num == -1: - max_address_num = len(pred_res) if self.GKB_flag: - all_candidates = self.get_candidates_GKB(key, len(pred_res)) - return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) + return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) else: return self.abduction(pred_res, key, max_address_num, require_more_address) - def get_candidates_GKB(self, key, length = None): - if self.base == {}: + def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): + if self.base == {} or len(pred_res) not in self.len_list: return [] - if key is None: - return self.get_all_candidates() + all_candidates = self.base[len(pred_res)][key] - if length is None: - length = list(self.base.keys()) - elif type(length) is int and length not in self.len_list: - return [] + if len(all_candidates) == 0: + candidates = [] + min_address_num = 0 + address_num = 0 else: - length = [length] - - return sum([self.base[l][key] for l in length], []) + cost_list = self.hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + + return candidates, min_address_num, address_num - def get_all_candidates(self): - if self.base == {}: - return [] - else: - return sum([sum(v.values(), []) for v in self.base.values()], []) + def hamming_dist(self, A, B): + B = np.array(B) + A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) + return np.sum(A != B, axis = 1) - - - - - def abduction(self, pred_res, key, max_address_num, require_more_address): - candidates = [] - for address_num in range(len(pred_res) + 1): - if address_num == 0: - if abs(self.logic_forward(pred_res) - key) <= 1e-3: - candidates.append(pred_res) - else: - new_candidates = self.address(address_num, pred_res, key) - candidates += new_candidates - - if len(candidates) > 0: - min_address_num = address_num - break - - if address_num >= max_address_num: - return [], 0, 0 - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if address_num > max_address_num: - return candidates, min_address_num, address_num - 1 - new_candidates = self.address(address_num, pred_res, key) - candidates += new_candidates - return candidates, min_address_num, address_num def address(self, address_num, pred_res, key): new_candidates = [] @@ -214,9 +195,6 @@ class hwf_KB(ClsKB): return round(eval(''.join(formula)), 2) - - - class prolog_KB(KBBase): def __init__(self, pseudo_label_list): super().__init__() @@ -224,16 +202,29 @@ class prolog_KB(KBBase): self.prolog = pyswip.Prolog() for i in self.pseudo_label_list: self.prolog.assertz("pseudo_label(%s)" % i) - self.prolog_flag = True def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): - if max_address_num == -1: - max_address_num = len(pred_res) - all_candidates = self.get_candidates_prolog(key) - return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) + def abduce_candidates(self, pred_res, key, max_address_num, require_more_address): + return self.abduction(pred_res, key, max_address_num, require_more_address) + + def address(self, address_num, pred_res, key): + new_candidates = [] + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + for address_idx in address_idx_list: + query_string = "addition(" + for idx, i in enumerate(pred_res): + tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' + query_string += tmp + query_string += "%s)." + abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string % key))] + for c in abduce_c: + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + new_candidates.append(candidate) + return new_candidates class add_prolog_KB(prolog_KB): @@ -301,42 +292,7 @@ class RegKB(KBBase): import time if __name__ == "__main__": - # With ground KB - kb = add_KB(GKB_flag = True) - print('len(kb):', len(kb)) - res = kb.get_candidates_GKB(0) - print(res) - res = kb.get_candidates_GKB(18) - print(res) - res = kb.get_candidates_GKB(18) - print(res) - res = kb.get_candidates_GKB(16) - print(res) - print() - - # Without ground KB - kb = add_KB() - print('len(kb):', len(kb)) - print() - - # Prolog - kb = add_prolog_KB() - print(kb.logic_forward([3, 4])) - res = kb.get_candidates_prolog(16) - print(res) - - start = time.time() - kb = hwf_KB(GKB_flag = True, len_list = [1, 3, 5]) - print(time.time() - start) - print('len(kb):', len(kb)) - res = kb.get_candidates_GKB(2, length = 1) - print(res) - res = kb.get_candidates_GKB(1, length = 3) - print(res) - res = kb.get_candidates_GKB(3.67, length = 5) - print(res) - print() - + pass # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]