diff --git a/abducer/kb.py b/abducer/kb.py index db2bf6f..9ae8efd 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -22,7 +22,7 @@ import pyswip class KBBase(ABC): - def __init__(self): + def __init__(self, pseudo_label_list = None): pass @abstractmethod @@ -34,16 +34,16 @@ class KBBase(ABC): pass - def address(self, address_num, pred_res, key): + def address(self, address_num, pred_res, key, multiple_predictions = False): new_candidates = [] address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) for address_idx in address_idx_list: - candidates = self.address_by_idx(pred_res, key, address_idx) + candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) new_candidates += candidates return new_candidates - def abduction(self, pred_res, key, max_address_num, require_more_address): + def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): candidates = [] for address_num in range(len(pred_res) + 1): @@ -51,7 +51,7 @@ class KBBase(ABC): if abs(self.logic_forward(pred_res) - key) <= 1e-3: candidates.append(pred_res) else: - new_candidates = self.address(address_num, pred_res, key) + new_candidates = self.address(address_num, pred_res, key, multiple_predictions) candidates += new_candidates if len(candidates) > 0: @@ -64,12 +64,30 @@ class KBBase(ABC): for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): if address_num > max_address_num: return candidates, min_address_num, address_num - 1 - new_candidates = self.address(address_num, pred_res, key) + new_candidates = self.address(address_num, pred_res, key, multiple_predictions) candidates += new_candidates return candidates, min_address_num, address_num + # for multiple predictions, modify from `learn_add.py` + def flatten(self, l): + return [item for sublist in l for item in sublist] + + # for multiple predictions, modify from `learn_add.py` + def reform_ids(self, flatten_pred_res, save_pred_res): + re = [] + i = 0 + for e in save_pred_res: + j = 0 + ids = [] + while j < len(e): + ids.append(flatten_pred_res[i + j]) + j += 1 + re.append(ids) + i = i + j + return re + def __len__(self): pass @@ -82,7 +100,6 @@ class ClsKB(KBBase): self.prolog_flag = False if GKB_flag: - # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() self.base = {} X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) for x, y in zip(X, Y): @@ -109,11 +126,11 @@ class ClsKB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): + def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0, multiple_predictions = False): if self.GKB_flag: return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) else: - return self.abduction(pred_res, key, max_address_num, require_more_address) + return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) @@ -142,14 +159,23 @@ class ClsKB(KBBase): return candidates, min_address_num, address_num - def address_by_idx(self, pred_res, key, address_idx): + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): candidates = [] abduce_c = self.all_address_candidate_dict[len(address_idx)] + + if multiple_predictions: + save_pred_res = pred_res + pred_res = self.flatten(pred_res) + for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - if self.logic_forward(candidate) == key: + + if multiple_predictions: + candidate = self.reform_ids(candidate, save_pred_res) + + if self.logic_forward(candidate) == key: candidates.append(candidate) return candidates @@ -176,7 +202,7 @@ class add_KB(ClsKB): def logic_forward(self, nums): return sum(nums) -class hwf_KB(ClsKB): +class HWF_KB(ClsKB): def __init__(self, GKB_flag = False, \ pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ len_list = [1, 3, 5, 7]): @@ -205,62 +231,99 @@ class prolog_KB(KBBase): super().__init__() self.pseudo_label_list = pseudo_label_list self.prolog = pyswip.Prolog() - for i in self.pseudo_label_list: - self.prolog.assertz("pseudo_label(%s)" % i) + def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num, require_more_address): - return self.abduction(pred_res, key, max_address_num, require_more_address) + def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) - def address_by_idx(self, pred_res, key, address_idx, verbose=True): + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): candidates = [] - query_string = self.get_query_string(pred_res, address_idx) - abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string % key))] + + if not multiple_predictions: + query_string = self.get_query_string(pred_res, key, address_idx) + else: + query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) + + if multiple_predictions: + save_pred_res = pred_res + pred_res = self.flatten(pred_res) + + abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] + + if multiple_predictions: + candidate = self.reform_ids(candidate, save_pred_res) + candidates.append(candidate) - return candidates - + return candidates class add_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list = list(range(10))): super().__init__(pseudo_label_list) + for i in self.pseudo_label_list: + self.prolog.assertz("pseudo_label(%s)" % i) self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") def logic_forward(self, nums): return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res'] - def get_query_string(self, pred_res, address_idx): + def get_query_string(self, pred_res, key, address_idx): query_string = "addition(" for idx, i in enumerate(pred_res): tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' query_string += tmp - query_string += "%s)." + query_string += "%s)." % key return query_string -class hed_prolog_KB(prolog_KB): - def __init__(self, pseudo_label_list = list(range(10))): +class HED_prolog_KB(prolog_KB): + def __init__(self, pseudo_label_list = [0, 1, '+', '=']): super().__init__(pseudo_label_list) - self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") + self.prolog.consult('../datasets/hed/learn_add.pl') - def logic_forward(self, nums): - return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res'] + # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` + def logic_forward(self, exs): + return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 - def get_query_string(self, pred_res, address_idx): - query_string = "addition(" - for idx, i in enumerate(pred_res): - tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' - query_string += tmp - query_string += "%s)." - return query_string + + def get_query_string_need_flatten(self, pred_res, key, address_idx): + # flatten + flatten_pred_res = self.flatten(pred_res) + # add variables for prolog + for idx in range(len(flatten_pred_res)): + if idx in address_idx: + flatten_pred_res[idx] = 'X' + str(idx) + # unflatten + new_pred_res = self.reform_ids(flatten_pred_res, pred_res) + + query_string = "abduce_consistent_insts(%s)." % new_pred_res + return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") + + + + def consist_rule(self, exs, rules): + rule_str = "%s" % rules + rule_str = rule_str.replace("'", "") + return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." %(exs, rule_str)))) != 0 + + def abduce_rules(self, pred_res): + prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X'] + rules = [] + for rule in prolog_rules: + rules.append(rule.value) + return rules + + # def consist_rules(self, pred_res, rules): + class RegKB(KBBase):