Browse Source

Add kb and abducer for HED

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
c4c2f7e030
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 98 additions and 35 deletions
  1. +98
    -35
      abducer/kb.py

+ 98
- 35
abducer/kb.py View File

@@ -22,7 +22,7 @@ import pyswip


class KBBase(ABC):
def __init__(self):
def __init__(self, pseudo_label_list = None):
pass

@abstractmethod
@@ -34,16 +34,16 @@ class KBBase(ABC):
pass


def address(self, address_num, pred_res, key):
def address(self, address_num, pred_res, key, multiple_predictions = False):
new_candidates = []
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates
def abduction(self, pred_res, key, max_address_num, require_more_address):
def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False):
candidates = []
for address_num in range(len(pred_res) + 1):
@@ -51,7 +51,7 @@ class KBBase(ABC):
if abs(self.logic_forward(pred_res) - key) <= 1e-3:
candidates.append(pred_res)
else:
new_candidates = self.address(address_num, pred_res, key)
new_candidates = self.address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates
if len(candidates) > 0:
@@ -64,12 +64,30 @@ class KBBase(ABC):
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self.address(address_num, pred_res, key)
new_candidates = self.address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

return candidates, min_address_num, address_num
# for multiple predictions, modify from `learn_add.py`
def flatten(self, l):
return [item for sublist in l for item in sublist]
# for multiple predictions, modify from `learn_add.py`
def reform_ids(self, flatten_pred_res, save_pred_res):
re = []
i = 0
for e in save_pred_res:
j = 0
ids = []
while j < len(e):
ids.append(flatten_pred_res[i + j])
j += 1
re.append(ids)
i = i + j
return re
def __len__(self):
pass

@@ -82,7 +100,6 @@ class ClsKB(KBBase):
self.prolog_flag = False
if GKB_flag:
# self.base = np.load('abducer/hwf.npy', allow_pickle=True).item()
self.base = {}
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
for x, y in zip(X, Y):
@@ -109,11 +126,11 @@ class ClsKB(KBBase):
def logic_forward(self):
pass
def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0):
def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0, multiple_predictions = False):
if self.GKB_flag:
return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address)
else:
return self.abduction(pred_res, key, max_address_num, require_more_address)
return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)



@@ -142,14 +159,23 @@ class ClsKB(KBBase):
return candidates, min_address_num, address_num

def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False):
candidates = []
abduce_c = self.all_address_candidate_dict[len(address_idx)]
if multiple_predictions:
save_pred_res = pred_res
pred_res = self.flatten(pred_res)
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if self.logic_forward(candidate) == key:
if multiple_predictions:
candidate = self.reform_ids(candidate, save_pred_res)
if self.logic_forward(candidate) == key:
candidates.append(candidate)
return candidates
@@ -176,7 +202,7 @@ class add_KB(ClsKB):
def logic_forward(self, nums):
return sum(nums)
class hwf_KB(ClsKB):
class HWF_KB(ClsKB):
def __init__(self, GKB_flag = False, \
pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \
len_list = [1, 3, 5, 7]):
@@ -205,62 +231,99 @@ class prolog_KB(KBBase):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.prolog = pyswip.Prolog()
for i in self.pseudo_label_list:
self.prolog.assertz("pseudo_label(%s)" % i)
def logic_forward(self):
pass
def abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
return self.abduction(pred_res, key, max_address_num, require_more_address)
def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)
def address_by_idx(self, pred_res, key, address_idx, verbose=True):
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False):
candidates = []
query_string = self.get_query_string(pred_res, address_idx)
abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string % key))]
if not multiple_predictions:
query_string = self.get_query_string(pred_res, key, address_idx)
else:
query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
if multiple_predictions:
save_pred_res = pred_res
pred_res = self.flatten(pred_res)

abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if multiple_predictions:
candidate = self.reform_ids(candidate, save_pred_res)
candidates.append(candidate)
return candidates
return candidates

class add_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list = list(range(10))):
super().__init__(pseudo_label_list)
for i in self.pseudo_label_list:
self.prolog.assertz("pseudo_label(%s)" % i)
self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")
def logic_forward(self, nums):
return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res']
def get_query_string(self, pred_res, address_idx):
def get_query_string(self, pred_res, key, address_idx):
query_string = "addition("
for idx, i in enumerate(pred_res):
tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ','
query_string += tmp
query_string += "%s)."
query_string += "%s)." % key
return query_string

class hed_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list = list(range(10))):
class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list = [0, 1, '+', '=']):
super().__init__(pseudo_label_list)
self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")
self.prolog.consult('../datasets/hed/learn_add.pl')
def logic_forward(self, nums):
return list(self.prolog.query("addition(%s, %s, Res)." %(nums[0], nums[1])))[0]['Res']
# corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0
def get_query_string(self, pred_res, address_idx):
query_string = "addition("
for idx, i in enumerate(pred_res):
tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ','
query_string += tmp
query_string += "%s)."
return query_string

def get_query_string_need_flatten(self, pred_res, key, address_idx):
# flatten
flatten_pred_res = self.flatten(pred_res)
# add variables for prolog
for idx in range(len(flatten_pred_res)):
if idx in address_idx:
flatten_pred_res[idx] = 'X' + str(idx)
# unflatten
new_pred_res = self.reform_ids(flatten_pred_res, pred_res)
query_string = "abduce_consistent_insts(%s)." % new_pred_res
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")
def consist_rule(self, exs, rules):
rule_str = "%s" % rules
rule_str = rule_str.replace("'", "")
return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." %(exs, rule_str)))) != 0
def abduce_rules(self, pred_res):
prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X']
rules = []
for rule in prolog_rules:
rules.append(rule.value)
return rules
# def consist_rules(self, pred_res, rules):


class RegKB(KBBase):


Loading…
Cancel
Save