|
|
|
@@ -229,30 +229,52 @@ class prolog_KB(KBBase): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.prolog = pyswip.Prolog() |
|
|
|
|
|
|
|
def logic_forward(self): |
|
|
|
pass |
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] |
|
|
|
if result == 'true': |
|
|
|
return True |
|
|
|
elif result == 'false': |
|
|
|
return False |
|
|
|
return result |
|
|
|
|
|
|
|
def _address_pred_res(self, pred_res, address_idx, multiple_predictions): |
|
|
|
import re |
|
|
|
address_pred_res = pred_res.copy() |
|
|
|
if multiple_predictions: |
|
|
|
address_pred_res = flatten(address_pred_res) |
|
|
|
|
|
|
|
for idx in range(len(address_pred_res)): |
|
|
|
if idx in address_idx: |
|
|
|
address_pred_res[idx] = 'P' + str(idx) |
|
|
|
if multiple_predictions: |
|
|
|
address_pred_res = reform_idx(address_pred_res, pred_res) |
|
|
|
|
|
|
|
regex = r"'P\d+'" |
|
|
|
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) |
|
|
|
|
|
|
|
def get_query_string(self, pred_res, key, address_idx, multiple_predictions): |
|
|
|
query_string = "logic_forward(" |
|
|
|
query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions) |
|
|
|
key_is_none_flag = key is None or (type(key) == list and key[0] is None) |
|
|
|
query_string += ",%s)." % key if not key_is_none_flag else ")." |
|
|
|
return query_string |
|
|
|
|
|
|
|
def _find_candidate_GKB(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): |
|
|
|
candidates = [] |
|
|
|
# print(address_idx) |
|
|
|
query_string = self.get_query_string(pred_res, key, address_idx) |
|
|
|
|
|
|
|
query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions) |
|
|
|
if multiple_predictions: |
|
|
|
save_pred_res = pred_res |
|
|
|
pred_res = flatten(pred_res) |
|
|
|
|
|
|
|
abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] |
|
|
|
for c in abduce_c: |
|
|
|
candidate = pred_res.copy() |
|
|
|
for i, idx in enumerate(address_idx): |
|
|
|
candidate[idx] = c[i] |
|
|
|
|
|
|
|
if multiple_predictions: |
|
|
|
candidate = reform_idx(candidate, save_pred_res) |
|
|
|
|
|
|
|
candidates.append(candidate) |
|
|
|
return candidates |
|
|
|
|
|
|
|
@@ -264,37 +286,12 @@ class add_prolog_KB(prolog_KB): |
|
|
|
self.prolog.assertz("pseudo_label(%s)" % i) |
|
|
|
self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") |
|
|
|
|
|
|
|
def logic_forward(self, nums): |
|
|
|
return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res'] |
|
|
|
|
|
|
|
def get_query_string(self, pred_res, key, address_idx): |
|
|
|
query_string = "addition(" |
|
|
|
for idx, i in enumerate(pred_res): |
|
|
|
tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' |
|
|
|
query_string += tmp |
|
|
|
query_string += "%s)." % key |
|
|
|
return query_string |
|
|
|
|
|
|
|
|
|
|
|
class HED_prolog_KB(prolog_KB): |
|
|
|
def __init__(self, pseudo_label_list=[0, 1, '+', '=']): |
|
|
|
super().__init__(pseudo_label_list) |
|
|
|
self.prolog.consult('./datasets/hed/learn_add.pl') |
|
|
|
|
|
|
|
def logic_forward(self, exs): |
|
|
|
return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 |
|
|
|
|
|
|
|
def get_query_string(self, pred_res, key, address_idx): |
|
|
|
flatten_pred_res = flatten(pred_res) |
|
|
|
# add variables for prolog |
|
|
|
for idx in range(len(flatten_pred_res)): |
|
|
|
if idx in address_idx: |
|
|
|
flatten_pred_res[idx] = 'X' + str(idx) |
|
|
|
pred_res = reform_idx(flatten_pred_res, pred_res) |
|
|
|
|
|
|
|
query_string = "abduce_consistent_insts(%s)." % pred_res |
|
|
|
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") |
|
|
|
|
|
|
|
def consist_rule(self, exs, rules): |
|
|
|
rules = str(rules).replace("\'","") |
|
|
|
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 |
|
|
|
|