From d4ef2ca8d9730e03a0d5307977e273ff226f021f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 7 Dec 2022 18:49:23 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 72 ++++++++++++++++++--------------------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 7ade7f9..f44ffec 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -15,8 +15,12 @@ import bisect import copy import numpy as np +import sys +sys.path.append("..") + from collections import defaultdict from itertools import product, combinations +from utils.utils import _flatten, _reform_ids, _hamming_dist import pyswip @@ -39,25 +43,21 @@ class KBBase(ABC): if not multiple_predictions: address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) else: - address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num)) + address_idx_list = list(combinations(list(range(len(_flatten(pred_res)))), address_num)) for address_idx in address_idx_list: candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) new_candidates += candidates return new_candidates - def correct_result(self, pred_res, key): - if type(key) != bool: - return abs(self.logic_forward(pred_res) - key) <= 1e-3 - else: - return self.logic_forward(pred_res) + def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): candidates = [] for address_num in range(len(pred_res) + 1): if address_num == 0: - if self.correct_result(pred_res, key): + if abs(self.logic_forward(pred_res) - key) <= 1e-3: candidates.append(pred_res) else: new_candidates = self.address(address_num, pred_res, key, multiple_predictions) @@ -79,23 +79,7 @@ class KBBase(ABC): return candidates, min_address_num, address_num - # for multiple predictions, modify from `learn_add.py` - def flatten(self, l): - return [item for sublist in l for item in sublist] - - # for multiple predictions, modify from `learn_add.py` - def reform_ids(self, flatten_pred_res, save_pred_res): - re = [] - i = 0 - for e in save_pred_res: - j = 0 - ids = [] - while j < len(e): - ids.append(flatten_pred_res[i + j]) - j += 1 - re.append(ids) - i = i + j - return re + def __len__(self): pass @@ -110,7 +94,7 @@ class ClsKB(KBBase): if GKB_flag: self.base = {} - X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) + X, Y = self._get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(x) else: @@ -118,7 +102,7 @@ class ClsKB(KBBase): for address_num in range(max(self.len_list) + 1): self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num)) - def get_GKB(self, pseudo_label_list, len_list): + def _get_GKB(self, pseudo_label_list, len_list): all_X = [] for len in len_list: all_X += list(product(pseudo_label_list, repeat = len)) @@ -142,12 +126,6 @@ class ClsKB(KBBase): return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) - - def hamming_dist(self, A, B): - B = np.array(B) - A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) - return np.sum(A != B, axis = 1) - def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): if self.base == {} or len(pred_res) not in self.len_list: return [] @@ -159,7 +137,7 @@ class ClsKB(KBBase): min_address_num = 0 address_num = 0 else: - cost_list = self.hamming_dist(pred_res, all_candidates) + cost_list = _hamming_dist(pred_res, all_candidates) min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) idxs = np.where(cost_list <= address_num)[0] @@ -174,7 +152,7 @@ class ClsKB(KBBase): if multiple_predictions: save_pred_res = pred_res - pred_res = self.flatten(pred_res) + pred_res = _flatten(pred_res) for c in abduce_c: candidate = pred_res.copy() @@ -182,7 +160,7 @@ class ClsKB(KBBase): candidate[idx] = c[i] if multiple_predictions: - candidate = self.reform_ids(candidate, save_pred_res) + candidate = _reform_ids(candidate, save_pred_res) if self.logic_forward(candidate) == key: candidates.append(candidate) @@ -252,15 +230,15 @@ class prolog_KB(KBBase): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): candidates = [] - + # print(address_idx) if not multiple_predictions: query_string = self.get_query_string(pred_res, key, address_idx) else: - query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) + query_string = self.get_query_string_need__flatten(pred_res, key, address_idx) if multiple_predictions: save_pred_res = pred_res - pred_res = self.flatten(pred_res) + pred_res = _flatten(pred_res) abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] for c in abduce_c: @@ -269,7 +247,7 @@ class prolog_KB(KBBase): candidate[idx] = c[i] if multiple_predictions: - candidate = self.reform_ids(candidate, save_pred_res) + candidate = _reform_ids(candidate, save_pred_res) candidates.append(candidate) return candidates @@ -297,22 +275,22 @@ class add_prolog_KB(prolog_KB): class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list = [0, 1, '+', '=']): super().__init__(pseudo_label_list) - self.prolog.consult('../datasets/hed/learn_add.pl') + self.prolog.consult('./datasets/hed/learn_add.pl') # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 - def get_query_string_need_flatten(self, pred_res, key, address_idx): - # flatten - flatten_pred_res = self.flatten(pred_res) + def get_query_string_need__flatten(self, pred_res, key, address_idx): + # _flatten + _flatten_pred_res = _flatten(pred_res) # add variables for prolog - for idx in range(len(flatten_pred_res)): + for idx in range(len(_flatten_pred_res)): if idx in address_idx: - flatten_pred_res[idx] = 'X' + str(idx) - # unflatten - new_pred_res = self.reform_ids(flatten_pred_res, pred_res) + _flatten_pred_res[idx] = 'X' + str(idx) + # un_flatten + new_pred_res = _reform_ids(_flatten_pred_res, pred_res) query_string = "abduce_consistent_insts(%s)." % new_pred_res return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")