Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
cdb8bc2067
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 63 additions and 107 deletions
  1. +63
    -107
      abducer/kb.py

+ 63
- 107
abducer/kb.py View File

@@ -32,23 +32,32 @@ class KBBase(ABC):
# Candidate-abduction hook of KBBase; this base version is an empty stub.
# The subclasses visible in this file (ClsKB, prolog_KB) define it with the
# parameters (pred_res, key, max_address_num, require_more_address).
def abduce_candidates(self):
pass
def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address):
    """Keep the candidates that are closest to ``pred_res``.

    Closeness is measured with ``self.hamming_dist``.

    Args:
        pred_res: Predicted label sequence the candidates are compared with.
        all_candidates: Candidate label sequences to filter.
        max_address_num: Upper bound on the accepted distance.
        require_more_address: Extra distance allowed on top of the smallest
            distance found among ``all_candidates``.

    Returns:
        A tuple ``(candidates, min_address_num, address_num)``: the retained
        candidates, the smallest distance found, and the distance threshold
        that was applied. ``([], 0, 0)`` when ``all_candidates`` is empty.
    """
    if len(all_candidates) == 0:
        return [], 0, 0

    cost_list = self.hamming_dist(pred_res, all_candidates)
    min_address_num = np.min(cost_list)
    # Never accept a larger distance than the caller allows.
    address_num = min(max_address_num, min_address_num + require_more_address)
    idxs = np.where(cost_list <= address_num)[0]
    candidates = [all_candidates[idx] for idx in idxs]
    # Callers return this method's result directly, so the triple must be
    # handed back explicitly.
    return candidates, min_address_num, address_num
def abduction(self, pred_res, key, max_address_num, require_more_address):
    """Search for revisions of ``pred_res`` that are consistent with ``key``.

    The number of revised positions grows from 0 upwards until at least one
    candidate is found. With 0 revisions, ``pred_res`` itself is accepted
    when ``self.logic_forward(pred_res)`` is within 1e-3 of ``key``;
    otherwise candidates come from ``self.address``. After the first hit, up
    to ``require_more_address`` further revision counts are explored, never
    going beyond ``max_address_num``.

    Args:
        pred_res: Predicted label sequence.
        key: Target value the logic result must match.
        max_address_num: Largest number of revised positions to try.
        require_more_address: Extra revision counts to explore after the
            first count that yields a candidate.

    Returns:
        A tuple ``(candidates, min_address_num, address_num)``: the collected
        candidates, the smallest revision count that produced one, and the
        last revision count explored. ``([], 0, 0)`` when nothing is found.
    """
    candidates = []
    min_address_num = None
    for address_num in range(len(pred_res) + 1):
        if address_num == 0:
            if abs(self.logic_forward(pred_res) - key) <= 1e-3:
                candidates.append(pred_res)
        else:
            candidates += self.address(address_num, pred_res, key)
        if candidates:
            min_address_num = address_num
            break
        if address_num >= max_address_num:
            return [], 0, 0

    # The search can run out of revision counts without a hit when
    # max_address_num exceeds len(pred_res); report "nothing found" instead
    # of touching an unset min_address_num.
    if min_address_num is None:
        return [], 0, 0

    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            return candidates, min_address_num, address_num - 1
        candidates += self.address(address_num, pred_res, key)

    return candidates, min_address_num, address_num
def hamming_dist(self, A, B):
    """Count, for every row of ``B``, the positions where it differs from ``A``.

    Args:
        A: One label sequence.
        B: A sequence of label sequences, each as long as ``A``.

    Returns:
        A 1-D integer array with one count per row of ``B``; an empty array
        when ``B`` has no rows.
    """
    B = np.array(B)
    if len(B) == 0:
        # No rows to compare against; an empty 1-D array cannot be
        # broadcast row-wise against A.
        return np.zeros(0, dtype=int)
    # Broadcasting compares A with each row of B, so A need not be tiled.
    return np.sum(np.asarray(A) != B, axis=1)
# Size hook; this version has no body beyond `pass`.
# NOTE(review): the __main__ section prints len(kb) for several KB objects, so
# concrete classes presumably supply a real implementation — confirm.
def __len__(self):
pass
@@ -90,66 +99,38 @@ class ClsKB(KBBase):
pass
# Entry point for abducing candidates for `key` from the predictions `pred_res`.
# A max_address_num of -1 is replaced by len(pred_res).
# With self.GKB_flag set, candidates come from the ground knowledge base;
# otherwise self.abduction(...) is used.
# NOTE(review): the GKB branch shows two consecutive returns (one through
# get_candidates_GKB + filter_all_candidates, one through abduce_from_GKB).
# This looks like the old and new versions of the line shown together —
# confirm which one is current.
def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0):
if max_address_num == -1:
max_address_num = len(pred_res)
if self.GKB_flag:
all_candidates = self.get_candidates_GKB(key, len(pred_res))
return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address)
return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address)
else:
return self.abduction(pred_res, key, max_address_num, require_more_address)



# NOTE(review): two method headers appear in this span (get_candidates_GKB and
# abduce_from_GKB) and their statements are interleaved. Lines that use
# `length` match get_candidates_GKB's parameters; lines that use `pred_res`,
# `max_address_num` and `require_more_address` match abduce_from_GKB's.
# Separate the two bodies before relying on either.
#
# Statements using `length`: look up self.base[l][key] for one length or for
# every length in self.base and concatenate the lists.
# Statements using `pred_res`: take self.base[len(pred_res)][key], rank the
# entries with self.hamming_dist and keep those within
# min(max_address_num, min_address_num + require_more_address).
def get_candidates_GKB(self, key, length = None):
if self.base == {}:
def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {} or len(pred_res) not in self.len_list:
return []
if key is None:
return self.get_all_candidates()
all_candidates = self.base[len(pred_res)][key]
if length is None:
length = list(self.base.keys())
elif type(length) is int and length not in self.len_list:
return []
if len(all_candidates) == 0:
candidates = []
min_address_num = 0
address_num = 0
else:
length = [length]
return sum([self.base[l][key] for l in length], [])
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
def get_all_candidates(self):
    """Return every candidate stored in ``self.base`` as one flat list.

    ``self.base`` is read as a two-level mapping whose innermost values are
    lists of candidates. The lists are concatenated in iteration order.

    Returns:
        A new list with all stored candidates; ``[]`` when ``self.base`` is
        empty.
    """
    # A single flattening pass avoids the repeated list copies that
    # sum(..., []) performs, and handles an empty base without a special case.
    return [
        candidate
        for by_key in self.base.values()
        for group in by_key.values()
        for candidate in group
    ]
def hamming_dist(self, A, B):
    """Return how many positions of each row of ``B`` differ from ``A``.

    Args:
        A: One label sequence.
        B: A sequence of label sequences, each as long as ``A``.

    Returns:
        A 1-D integer array holding one mismatch count per row of ``B``; an
        empty array when ``B`` has no rows.
    """
    B = np.array(B)
    if len(B) == 0:
        # Nothing to compare: an empty 1-D array cannot be matched row-wise
        # against A.
        return np.zeros(0, dtype=int)
    # NumPy broadcasting lines A up with every row of B without tiling A.
    return np.sum(np.asarray(A) != B, axis=1)



def abduction(self, pred_res, key, max_address_num, require_more_address):
    """Search for revisions of ``pred_res`` that are consistent with ``key``.

    Revision counts are tried in increasing order starting at 0. With 0
    revisions, ``pred_res`` itself is kept when
    ``self.logic_forward(pred_res)`` is within 1e-3 of ``key``; for larger
    counts the candidates come from ``self.address``. Once a count yields a
    candidate, up to ``require_more_address`` further counts are explored,
    capped at ``max_address_num``.

    Args:
        pred_res: Predicted label sequence.
        key: Target value the logic result must match.
        max_address_num: Largest number of revised positions to try.
        require_more_address: Extra revision counts to explore after the
            first successful one.

    Returns:
        A tuple ``(candidates, min_address_num, address_num)``: the collected
        candidates, the smallest successful revision count, and the last
        revision count explored. ``([], 0, 0)`` when nothing is found.
    """
    candidates = []
    min_address_num = None
    for address_num in range(len(pred_res) + 1):
        if address_num == 0:
            if abs(self.logic_forward(pred_res) - key) <= 1e-3:
                candidates.append(pred_res)
        else:
            candidates += self.address(address_num, pred_res, key)
        if candidates:
            min_address_num = address_num
            break
        if address_num >= max_address_num:
            return [], 0, 0

    # When max_address_num is larger than len(pred_res) the loop above can
    # finish without a hit; return the "nothing found" triple rather than
    # reading an unset min_address_num.
    if min_address_num is None:
        return [], 0, 0

    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            return candidates, min_address_num, address_num - 1
        candidates += self.address(address_num, pred_res, key)

    return candidates, min_address_num, address_num
def address(self, address_num, pred_res, key):
new_candidates = []
@@ -214,9 +195,6 @@ class hwf_KB(ClsKB):
return round(eval(''.join(formula)), 2)





class prolog_KB(KBBase):
def __init__(self, pseudo_label_list):
super().__init__()
@@ -224,16 +202,29 @@ class prolog_KB(KBBase):
self.prolog = pyswip.Prolog()
for i in self.pseudo_label_list:
self.prolog.assertz("pseudo_label(%s)" % i)
self.prolog_flag = True
# Placeholder: takes no inputs and does nothing in this class.
# NOTE(review): the __main__ section calls add_prolog_KB().logic_forward([3, 4]),
# so subclasses presumably redefine this with an argument — confirm.
def logic_forward(self):
pass
# Abduce candidates by fetching everything get_candidates_prolog(key) returns
# and filtering it against pred_res with filter_all_candidates.
# A max_address_num of -1 is replaced by len(pred_res).
# NOTE(review): a second abduce_candidates definition follows directly after
# this one; if both sit in the same class body, the later one replaces this
# one — confirm which is current.
def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0):
if max_address_num == -1:
max_address_num = len(pred_res)
all_candidates = self.get_candidates_prolog(key)
return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address)
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0):
    """Abduce candidates for ``key`` by delegating to ``self.abduction``.

    Args:
        pred_res: Predicted label sequence.
        key: Target value the candidates must be consistent with.
        max_address_num: Largest number of positions that may be revised;
            ``-1`` (the default) means ``len(pred_res)``.
        require_more_address: Extra revision counts to explore beyond the
            first successful one.

    Returns:
        Whatever ``self.abduction`` returns for the resolved arguments.
    """
    # Same defaults and -1 convention as the other abduce_candidates entry
    # points in this file, so every KB can be called the same way.
    if max_address_num == -1:
        max_address_num = len(pred_res)
    return self.abduction(pred_res, key, max_address_num, require_more_address)
def address(self, address_num, pred_res, key):
    """Ask Prolog for every way to revise ``address_num`` entries of ``pred_res``.

    For each choice of ``address_num`` positions, an ``addition(...)`` query
    is issued in which the chosen entries are replaced by variables named
    ``Z<index>``, the other entries are written as-is, and ``key`` is the
    last argument. Every solution yields one candidate: a copy of
    ``pred_res`` with the chosen positions overwritten by the solution's
    values, taken in order.

    Args:
        address_num: Number of positions to revise.
        pred_res: Predicted label sequence; it is copied, never modified.
        key: Value placed as the final argument of the query.

    Returns:
        A list of candidate sequences (possibly empty).
    """
    new_candidates = []
    for address_idx in combinations(range(len(pred_res)), address_num):
        terms = [
            "Z" + str(idx) if idx in address_idx else str(label)
            for idx, label in enumerate(pred_res)
        ]
        terms.append(str(key))
        # Join the final text directly instead of %-formatting a string that
        # already contains the labels: a '%' inside a label (or a tuple key)
        # would otherwise break the formatting step.
        query_string = "addition(" + ",".join(terms) + ")."
        solutions = [list(z.values()) for z in self.prolog.query(query_string)]
        for values in solutions:
            candidate = pred_res.copy()
            for position, idx in enumerate(address_idx):
                candidate[idx] = values[position]
            new_candidates.append(candidate)
    return new_candidates

class add_prolog_KB(prolog_KB):
@@ -301,42 +292,7 @@ class RegKB(KBBase):

import time
# Manual smoke run: builds several knowledge bases and prints their sizes and
# a few candidate lookups.
# NOTE(review): this section calls get_candidates_GKB and get_candidates_prolog;
# confirm those methods still exist before running it.
if __name__ == "__main__":
# With ground KB
kb = add_KB(GKB_flag = True)
print('len(kb):', len(kb))
res = kb.get_candidates_GKB(0)
print(res)
res = kb.get_candidates_GKB(18)
print(res)
# NOTE(review): same lookup (key 18) as the previous call.
res = kb.get_candidates_GKB(18)
print(res)
res = kb.get_candidates_GKB(16)
print(res)
print()
# Without ground KB
kb = add_KB()
print('len(kb):', len(kb))
print()
# Prolog
kb = add_prolog_KB()
print(kb.logic_forward([3, 4]))
res = kb.get_candidates_prolog(16)
print(res)
# hwf_KB with a ground KB for lengths 1, 3 and 5; the construction is timed.
start = time.time()
kb = hwf_KB(GKB_flag = True, len_list = [1, 3, 5])
print(time.time() - start)
print('len(kb):', len(kb))
res = kb.get_candidates_GKB(2, length = 1)
print(res)
res = kb.get_candidates_GKB(1, length = 3)
print(res)
res = kb.get_candidates_GKB(3.67, length = 5)
print(res)
print()
pass
# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]


Loading…
Cancel
Save