| @@ -22,7 +22,7 @@ import pyswip | |||||
| class KBBase(ABC): | class KBBase(ABC): | ||||
def __init__(self, pseudo_label_list=None):
    """Initialise the knowledge-base base class.

    Args:
        pseudo_label_list: Accepted but neither stored nor used by this
            base implementation.
    """
    pass
| @abstractmethod | @abstractmethod | ||||
| @@ -34,16 +34,16 @@ class KBBase(ABC): | |||||
| pass | pass | ||||
def address(self, address_num, pred_res, key, multiple_predictions=False):
    """Collect the candidates obtained by revising exactly ``address_num`` positions.

    Every combination of ``address_num`` positions is handed to
    ``self.address_by_idx`` and the returned candidate lists are
    concatenated in combination order.

    Args:
        address_num: Number of positions to revise in each combination.
        pred_res: Predicted symbols; a flat sequence, or a sequence of
            sequences when ``multiple_predictions`` is true.
        key: Target value forwarded unchanged to ``address_by_idx``.
        multiple_predictions: Whether ``pred_res`` is nested.

    Returns:
        A list with all candidates produced by ``address_by_idx``.
    """
    # ``address_by_idx`` indexes into the *flattened* prediction when
    # ``multiple_predictions`` is set, so positions have to be counted over
    # every symbol of every sub-sequence, not over the number of
    # sub-sequences.
    if multiple_predictions:
        position_count = sum(len(example) for example in pred_res)
    else:
        position_count = len(pred_res)

    new_candidates = []
    for address_idx in combinations(range(position_count), address_num):
        new_candidates += self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
    return new_candidates
| def abduction(self, pred_res, key, max_address_num, require_more_address): | |||||
| def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): | |||||
| candidates = [] | candidates = [] | ||||
| for address_num in range(len(pred_res) + 1): | for address_num in range(len(pred_res) + 1): | ||||
| @@ -51,7 +51,7 @@ class KBBase(ABC): | |||||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | if abs(self.logic_forward(pred_res) - key) <= 1e-3: | ||||
| candidates.append(pred_res) | candidates.append(pred_res) | ||||
| else: | else: | ||||
| new_candidates = self.address(address_num, pred_res, key) | |||||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | candidates += new_candidates | ||||
| if len(candidates) > 0: | if len(candidates) > 0: | ||||
| @@ -64,12 +64,30 @@ class KBBase(ABC): | |||||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | ||||
| if address_num > max_address_num: | if address_num > max_address_num: | ||||
| return candidates, min_address_num, address_num - 1 | return candidates, min_address_num, address_num - 1 | ||||
| new_candidates = self.address(address_num, pred_res, key) | |||||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||||
| candidates += new_candidates | candidates += new_candidates | ||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
| # for multiple predictions, modify from `learn_add.py` | |||||
def flatten(self, l):
    """Concatenate the items of every sub-sequence of ``l`` into one list.

    Only one level of nesting is removed; items that are themselves
    sequences are copied over as-is.

    Args:
        l: An iterable of iterables.

    Returns:
        A new flat list; ``l`` is not modified.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
| # for multiple predictions, modify from `learn_add.py` | |||||
def reform_ids(self, flatten_pred_res, save_pred_res):
    """Regroup a flat list so that it mirrors the nesting of ``save_pred_res``.

    Consecutive items of ``flatten_pred_res`` are dealt out into sub-lists
    whose lengths equal the lengths of the corresponding entries of
    ``save_pred_res``. Items left over after the last group are ignored.

    Args:
        flatten_pred_res: Flat, indexable sequence of symbols.
        save_pred_res: Nested sequence whose per-entry lengths define the
            grouping; only ``len()`` of each entry is used.

    Returns:
        A new list of lists with the same group sizes as ``save_pred_res``.

    Raises:
        IndexError: If ``flatten_pred_res`` has fewer items than the group
            sizes require.
    """
    regrouped = []
    start = 0
    for group in save_pred_res:
        size = len(group)
        # Explicit indexing (rather than slicing) keeps the IndexError on
        # inputs that are too short instead of silently truncating.
        regrouped.append([flatten_pred_res[start + offset] for offset in range(size)])
        start += size
    return regrouped
def __len__(self):
    """Placeholder with no size logic; the body is empty and returns ``None``.

    NOTE(review): presumably meant to be overridden by concrete knowledge
    bases -- confirm against the subclasses.
    """
    pass
| @@ -82,7 +100,6 @@ class ClsKB(KBBase): | |||||
| self.prolog_flag = False | self.prolog_flag = False | ||||
| if GKB_flag: | if GKB_flag: | ||||
| # self.base = np.load('abducer/hwf.npy', allow_pickle=True).item() | |||||
| self.base = {} | self.base = {} | ||||
| X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) | X, Y = self.get_GKB(self.pseudo_label_list, self.len_list) | ||||
| for x, y in zip(X, Y): | for x, y in zip(X, Y): | ||||
| @@ -109,11 +126,11 @@ class ClsKB(KBBase): | |||||
def logic_forward(self):
    """Placeholder for the knowledge base's forward reasoning; returns ``None``.

    NOTE(review): the subclass override visible in this file
    (``add_KB.logic_forward``) takes the symbols as one positional
    argument -- confirm that every concrete class overrides this method.
    """
    pass
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
    """Abduce candidates from the ground knowledge base or by search.

    Args:
        pred_res: Predicted symbols to be revised.
        key: Target value the revised symbols must be consistent with.
        max_address_num: Forwarded unchanged to the selected backend.
        require_more_address: Forwarded unchanged to the selected backend.
        multiple_predictions: Forwarded to ``self.abduction`` only.

    Returns:
        Whatever the selected backend returns.
    """
    if self.GKB_flag:
        # ``multiple_predictions`` is not forwarded on this path:
        # ``abduce_from_GKB`` is called with four arguments only.
        return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address)
    return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)
| @@ -142,14 +159,23 @@ class ClsKB(KBBase): | |||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
    """Try every stored replacement at ``address_idx`` and keep the consistent ones.

    Args:
        pred_res: Predicted symbols; nested when ``multiple_predictions``
            is true, otherwise flat. Must provide ``.copy()`` when flat.
        key: Value ``self.logic_forward(candidate)`` must equal for a
            candidate to be kept.
        address_idx: Positions to overwrite. With
            ``multiple_predictions`` they index the flattened prediction.
        multiple_predictions: Whether ``pred_res`` is nested.

    Returns:
        The list of candidates whose ``logic_forward`` result equals
        ``key``; nested like ``pred_res`` when ``multiple_predictions``
        is true. ``pred_res`` itself is not modified.
    """
    # Replacement tuples are looked up by the number of revised positions.
    replacements = self.all_address_candidate_dict[len(address_idx)]
    # Revisions are applied to a flat list; the nested input is kept so
    # that each candidate can be regrouped before it is checked.
    flat_pred_res = self.flatten(pred_res) if multiple_predictions else pred_res

    candidates = []
    for replacement in replacements:
        candidate = flat_pred_res.copy()
        for i, idx in enumerate(address_idx):
            candidate[idx] = replacement[i]
        if multiple_predictions:
            candidate = self.reform_ids(candidate, pred_res)
        if self.logic_forward(candidate) == key:
            candidates.append(candidate)
    return candidates
| @@ -176,7 +202,7 @@ class add_KB(ClsKB): | |||||
def logic_forward(self, nums):
    """Return the sum of ``nums``.

    Args:
        nums: Iterable of numbers.

    Returns:
        ``sum(nums)``; ``0`` for an empty iterable.
    """
    return sum(nums)
| class hwf_KB(ClsKB): | |||||
| class HWF_KB(ClsKB): | |||||
| def __init__(self, GKB_flag = False, \ | def __init__(self, GKB_flag = False, \ | ||||
| pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ | pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \ | ||||
| len_list = [1, 3, 5, 7]): | len_list = [1, 3, 5, 7]): | ||||
| @@ -205,62 +231,99 @@ class prolog_KB(KBBase): | |||||
| super().__init__() | super().__init__() | ||||
| self.pseudo_label_list = pseudo_label_list | self.pseudo_label_list = pseudo_label_list | ||||
| self.prolog = pyswip.Prolog() | self.prolog = pyswip.Prolog() | ||||
| for i in self.pseudo_label_list: | |||||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||||
def logic_forward(self):
    """Placeholder for the knowledge base's forward reasoning; returns ``None``.

    NOTE(review): the Prolog subclasses visible in this file define their
    own ``logic_forward`` taking one positional argument -- confirm that
    every concrete class overrides this method.
    """
    pass
def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False):
    """Abduce candidates by delegating to ``self.abduction``.

    Args:
        pred_res: Predicted symbols to be revised.
        key: Target value the revised symbols must be consistent with.
        max_address_num: Forwarded unchanged to ``self.abduction``.
        require_more_address: Forwarded unchanged to ``self.abduction``.
        multiple_predictions: Whether ``pred_res`` is nested. Defaults to
            ``False`` so that four-argument calls keep working.

    Returns:
        Whatever ``self.abduction`` returns.
    """
    return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
    """Ask Prolog for replacements at ``address_idx`` and build the candidates.

    Args:
        pred_res: Predicted symbols; nested when ``multiple_predictions``
            is true, otherwise flat. Must provide ``.copy()`` when flat.
        key: Target value, passed to the query-string builder.
        address_idx: Positions to overwrite. With
            ``multiple_predictions`` they index the flattened prediction.
        multiple_predictions: Selects
            ``get_query_string_need_flatten`` instead of
            ``get_query_string`` and regroups each candidate with
            ``reform_ids``.

    Returns:
        One candidate per Prolog solution, in solution order. ``pred_res``
        itself is not modified.
    """
    if multiple_predictions:
        query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
        flat_pred_res = self.flatten(pred_res)
    else:
        query_string = self.get_query_string(pred_res, key, address_idx)
        flat_pred_res = pred_res

    # Each solution is a mapping of bindings; only its values are used.
    # NOTE(review): this assumes the values come back in the same order as
    # ``address_idx`` -- confirm against the query-string builders.
    abduce_c = [list(solution.values()) for solution in list(self.prolog.query(query_string))]

    candidates = []
    for replacement in abduce_c:
        candidate = flat_pred_res.copy()
        for i, idx in enumerate(address_idx):
            candidate[idx] = replacement[i]
        if multiple_predictions:
            candidate = self.reform_ids(candidate, pred_res)
        candidates.append(candidate)
    return candidates
class add_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base for adding two pseudo-labels.

    On construction it asserts one ``pseudo_label/1`` fact per label and an
    ``addition/3`` rule on ``self.prolog``.
    """

    def __init__(self, pseudo_label_list=None):
        """Set up the facts and the addition rule.

        Args:
            pseudo_label_list: Labels to register as ``pseudo_label``
                facts. Defaults to the digits ``0`` through ``9``.
        """
        # Build the default per call: a list in the signature would be one
        # shared object stored on every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        super().__init__(pseudo_label_list)
        for label in self.pseudo_label_list:
            self.prolog.assertz("pseudo_label(%s)" % label)
        self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")

    def logic_forward(self, nums):
        """Return ``Res`` of the first solution of ``addition(nums[0], nums[1], Res)``.

        Raises:
            IndexError: If the query has no solution or ``nums`` has fewer
                than two items.
        """
        solutions = list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))
        return solutions[0]['Res']

    def get_query_string(self, pred_res, key, address_idx):
        """Build the ``addition(...)`` query with variables at ``address_idx``.

        Position ``i`` of ``pred_res`` becomes the variable ``Z<i>`` when
        ``i`` is in ``address_idx`` and the symbol's ``str()`` otherwise;
        ``key`` is appended as the final argument.

        Returns:
            The query string, e.g. ``"addition(Z0,3,5)."``.
        """
        terms = [
            'Z' + str(idx) if idx in address_idx else str(symbol)
            for idx, symbol in enumerate(pred_res)
        ]
        terms.append("%s" % (key,))
        return "addition(" + ",".join(terms) + ")."
class HED_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base that consults ``learn_add.pl``.

    All reasoning is delegated to the predicates
    ``abduce_consistent_insts/1`` and ``consistent_inst_feature/2``, which
    are queried on ``self.prolog``.
    """

    def __init__(self, pseudo_label_list=None):
        """Create the knowledge base and consult the Prolog program.

        Args:
            pseudo_label_list: Labels handed to the parent constructor.
                Defaults to ``[0, 1, '+', '=']``.
        """
        # Build the default per call: a list in the signature would be one
        # shared object stored on every instance.
        if pseudo_label_list is None:
            pseudo_label_list = [0, 1, '+', '=']
        super().__init__(pseudo_label_list)
        # NOTE(review): relative path -- presumably resolved against the
        # current working directory; confirm how the script is launched.
        self.prolog.consult('../datasets/hed/learn_add.pl')

    # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
    def logic_forward(self, exs):
        """Return whether ``abduce_consistent_insts(exs)`` has any solution."""
        solutions = list(self.prolog.query("abduce_consistent_insts(%s)." % exs))
        return len(solutions) != 0

    def get_query_string_need_flatten(self, pred_res, key, address_idx):
        """Build the ``abduce_consistent_insts`` query for nested predictions.

        Flattened positions listed in ``address_idx`` are replaced by the
        variables ``X<idx>``; the result is regrouped like ``pred_res``
        and formatted into the query. ``key`` is not used.

        Returns:
            The query string with all quotes stripped and every ``+`` and
            ``=`` re-quoted as ``'+'`` and ``'='``.
        """
        flat_symbols = self.flatten(pred_res)
        # Positions to abduce become Prolog variables.
        with_variables = [
            'X' + str(idx) if idx in address_idx else symbol
            for idx, symbol in enumerate(flat_symbols)
        ]
        new_pred_res = self.reform_ids(with_variables, pred_res)
        query_string = "abduce_consistent_insts(%s)." % new_pred_res
        # Strip the quotes Python's list repr puts around strings, then
        # quote only the operator symbols again.
        return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")

    def consist_rule(self, exs, rules):
        """Return whether ``consistent_inst_feature(exs, rules)`` has any solution.

        Quotes are stripped from the string form of ``rules`` before it is
        placed in the query.
        """
        rule_str = ("%s" % rules).replace("'", "")
        solutions = list(self.prolog.query("consistent_inst_feature(%s, %s)." % (exs, rule_str)))
        return len(solutions) != 0

    def abduce_rules(self, pred_res):
        """Return the ``.value`` of each rule bound to ``X`` in the first solution.

        Queries ``consistent_inst_feature(pred_res, X)``.

        Raises:
            IndexError: If the query has no solution.
        """
        solutions = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
        prolog_rules = solutions[0]['X']
        return [rule.value for rule in prolog_rules]
| # def consist_rules(self, pred_res, rules): | |||||
| class RegKB(KBBase): | class RegKB(KBBase): | ||||