| @@ -15,8 +15,12 @@ import bisect | |||
| import copy | |||
| import numpy as np | |||
| import sys | |||
| sys.path.append("..") | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from utils.utils import _flatten, _reform_ids, _hamming_dist | |||
| import pyswip | |||
| @@ -39,25 +43,21 @@ class KBBase(ABC): | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| else: | |||
| address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num)) | |||
| address_idx_list = list(combinations(list(range(len(_flatten(pred_res)))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def correct_result(self, pred_res, key): | |||
| if type(key) != bool: | |||
| return abs(self.logic_forward(pred_res) - key) <= 1e-3 | |||
| else: | |||
| return self.logic_forward(pred_res) | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if self.correct_result(pred_res, key): | |||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||
| @@ -79,23 +79,7 @@ class KBBase(ABC): | |||
| return candidates, min_address_num, address_num | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| def flatten(self, l): | |||
| return [item for sublist in l for item in sublist] | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| def reform_ids(self, flatten_pred_res, save_pred_res): | |||
| re = [] | |||
| i = 0 | |||
| for e in save_pred_res: | |||
| j = 0 | |||
| ids = [] | |||
| while j < len(e): | |||
| ids.append(flatten_pred_res[i + j]) | |||
| j += 1 | |||
| re.append(ids) | |||
| i = i + j | |||
| return re | |||
| def __len__(self): | |||
| pass | |||
| @@ -110,7 +94,7 @@ class ClsKB(KBBase): | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) | |||
| X, Y = self._get_GKB(self.pseudo_label_list, self.len_list) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| else: | |||
| @@ -118,7 +102,7 @@ class ClsKB(KBBase): | |||
| for address_num in range(max(self.len_list) + 1): | |||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| def get_GKB(self, pseudo_label_list, len_list): | |||
| def _get_GKB(self, pseudo_label_list, len_list): | |||
| all_X = [] | |||
| for len in len_list: | |||
| all_X += list(product(pseudo_label_list, repeat = len)) | |||
| @@ -142,12 +126,6 @@ class ClsKB(KBBase): | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def hamming_dist(self, A, B): | |||
| B = np.array(B) | |||
| A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) | |||
| return np.sum(A != B, axis = 1) | |||
| def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| return [] | |||
| @@ -159,7 +137,7 @@ class ClsKB(KBBase): | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| else: | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| cost_list = _hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| @@ -174,7 +152,7 @@ class ClsKB(KBBase): | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = self.flatten(pred_res) | |||
| pred_res = _flatten(pred_res) | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| @@ -182,7 +160,7 @@ class ClsKB(KBBase): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = self.reform_ids(candidate, save_pred_res) | |||
| candidate = _reform_ids(candidate, save_pred_res) | |||
| if self.logic_forward(candidate) == key: | |||
| candidates.append(candidate) | |||
| @@ -252,15 +230,15 @@ class prolog_KB(KBBase): | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): | |||
| candidates = [] | |||
| # print(address_idx) | |||
| if not multiple_predictions: | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| else: | |||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||
| query_string = self.get_query_string_need__flatten(pred_res, key, address_idx) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = self.flatten(pred_res) | |||
| pred_res = _flatten(pred_res) | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] | |||
| for c in abduce_c: | |||
| @@ -269,7 +247,7 @@ class prolog_KB(KBBase): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = self.reform_ids(candidate, save_pred_res) | |||
| candidate = _reform_ids(candidate, save_pred_res) | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -297,22 +275,22 @@ class add_prolog_KB(prolog_KB): | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list = [0, 1, '+', '=']): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.consult('../datasets/hed/learn_add.pl') | |||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 | |||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||
| # flatten | |||
| flatten_pred_res = self.flatten(pred_res) | |||
| def get_query_string_need__flatten(self, pred_res, key, address_idx): | |||
| # _flatten | |||
| _flatten_pred_res = _flatten(pred_res) | |||
| # add variables for prolog | |||
| for idx in range(len(flatten_pred_res)): | |||
| for idx in range(len(_flatten_pred_res)): | |||
| if idx in address_idx: | |||
| flatten_pred_res[idx] = 'X' + str(idx) | |||
| # unflatten | |||
| new_pred_res = self.reform_ids(flatten_pred_res, pred_res) | |||
| _flatten_pred_res[idx] = 'X' + str(idx) | |||
| # un_flatten | |||
| new_pred_res = _reform_ids(_flatten_pred_res, pred_res) | |||
| query_string = "abduce_consistent_insts(%s)." % new_pred_res | |||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | |||