| @@ -22,7 +22,7 @@ import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self): | |||
| def __init__(self, pseudo_label_list = None): | |||
| pass | |||
| @abstractmethod | |||
| @@ -34,16 +34,16 @@ class KBBase(ABC): | |||
| pass | |||
| def address(self, address_num, pred_res, key): | |||
| def address(self, address_num, pred_res, key, multiple_predictions = False): | |||
| new_candidates = [] | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx) | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address): | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| @@ -51,7 +51,7 @@ class KBBase(ABC): | |||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| @@ -64,12 +64,30 @@ class KBBase(ABC): | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| def flatten(self, l): | |||
| return [item for sublist in l for item in sublist] | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| def reform_ids(self, flatten_pred_res, save_pred_res): | |||
| re = [] | |||
| i = 0 | |||
| for e in save_pred_res: | |||
| j = 0 | |||
| ids = [] | |||
| while j < len(e): | |||
| ids.append(flatten_pred_res[i + j]) | |||
| j += 1 | |||
| re.append(ids) | |||
| i = i + j | |||
| return re | |||
| def __len__(self): | |||
| pass | |||
| @@ -82,7 +100,6 @@ class ClsKB(KBBase): | |||
| self.prolog_flag = False | |||
| if GKB_flag: | |||
| # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() | |||
| self.base = {} | |||
| X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) | |||
| for x, y in zip(X, Y): | |||
| @@ -109,11 +126,11 @@ class ClsKB(KBBase): | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): | |||
| def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0, multiple_predictions = False): | |||
| if self.GKB_flag: | |||
| return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address) | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| @@ -142,14 +159,23 @@ class ClsKB(KBBase): | |||
| return candidates, min_address_num, address_num | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): | |||
| candidates = [] | |||
| abduce_c = self.all_address_candidate_dict[len(address_idx)] | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = self.flatten(pred_res) | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if self.logic_forward(candidate) == key: | |||
| if multiple_predictions: | |||
| candidate = self.reform_ids(candidate, save_pred_res) | |||
| if self.logic_forward(candidate) == key: | |||
| candidates.append(candidate) | |||
| return candidates | |||
| @@ -176,7 +202,7 @@ class add_KB(ClsKB): | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class hwf_KB(ClsKB): | |||
| class HWF_KB(ClsKB): | |||
| def __init__(self, GKB_flag = False, \ | |||
| pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ | |||
| len_list = [1, 3, 5, 7]): | |||
| @@ -205,62 +231,99 @@ class prolog_KB(KBBase): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.prolog = pyswip.Prolog() | |||
| for i in self.pseudo_label_list: | |||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address): | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address) | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def address_by_idx(self, pred_res, key, address_idx, verbose=True): | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False): | |||
| candidates = [] | |||
| query_string = self.get_query_string(pred_res, address_idx) | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string % key))] | |||
| if not multiple_predictions: | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| else: | |||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = self.flatten(pred_res) | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = self.reform_ids(candidate, save_pred_res) | |||
| candidates.append(candidate) | |||
| return candidates | |||
| return candidates | |||
| class add_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list = list(range(10))): | |||
| super().__init__(pseudo_label_list) | |||
| for i in self.pseudo_label_list: | |||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||
| self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") | |||
| def logic_forward(self, nums): | |||
| return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res'] | |||
| def get_query_string(self, pred_res, address_idx): | |||
| def get_query_string(self, pred_res, key, address_idx): | |||
| query_string = "addition(" | |||
| for idx, i in enumerate(pred_res): | |||
| tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' | |||
| query_string += tmp | |||
| query_string += "%s)." | |||
| query_string += "%s)." % key | |||
| return query_string | |||
| class hed_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list = list(range(10))): | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list = [0, 1, '+', '=']): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") | |||
| self.prolog.consult('../datasets/hed/learn_add.pl') | |||
| def logic_forward(self, nums): | |||
| return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res'] | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 | |||
| def get_query_string(self, pred_res, address_idx): | |||
| query_string = "addition(" | |||
| for idx, i in enumerate(pred_res): | |||
| tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' | |||
| query_string += tmp | |||
| query_string += "%s)." | |||
| return query_string | |||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||
| # flatten | |||
| flatten_pred_res = self.flatten(pred_res) | |||
| # add variables for prolog | |||
| for idx in range(len(flatten_pred_res)): | |||
| if idx in address_idx: | |||
| flatten_pred_res[idx] = 'X' + str(idx) | |||
| # unflatten | |||
| new_pred_res = self.reform_ids(flatten_pred_res, pred_res) | |||
| query_string = "abduce_consistent_insts(%s)." % new_pred_res | |||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | |||
| def consist_rule(self, exs, rules): | |||
| rule_str = "%s" % rules | |||
| rule_str = rule_str.replace("'", "") | |||
| return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." %(exs, rule_str)))) != 0 | |||
| def abduce_rules(self, pred_res): | |||
| prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X'] | |||
| rules = [] | |||
| for rule in prolog_rules: | |||
| rules.append(rule.value) | |||
| return rules | |||
| # def consist_rules(self, pred_res, rules): | |||
| class RegKB(KBBase): | |||