Browse Source

Update prologKB

pull/3/head
troyyyyy 3 years ago
parent
commit
8ba7c509e8
4 changed files with 42 additions and 41 deletions
  1. +7
    -6
      abl/abducer/abducer_base.py
  2. +30
    -33
      abl/abducer/kb.py
  3. +2
    -2
      abl/framework_hed.py
  4. +3
    -0
      examples/datasets/hed/learn_add.pl

+ 7
- 6
abl/abducer/abducer_base.py View File

@@ -163,7 +163,7 @@ if __name__ == '__main__':
prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]

kb = add_KB(True)
kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res)
@@ -219,14 +219,15 @@ if __name__ == '__main__':
print(res)
print()

kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 0.1)
kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res)
print()
kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 1)
kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res)
@@ -270,11 +271,11 @@ if __name__ == '__main__':
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print()

res = abd.abduce((consist_exs, None, [1] * len(consist_exs)))
res = abd.abduce((consist_exs, None, [None] * len(consist_exs)))
print(res)
res = abd.abduce((inconsist_exs, None, [1] * len(consist_exs)))
res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs)))
print(res)
print()

abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)
print(abduced_rules)

+ 30
- 33
abl/abducer/kb.py View File

@@ -229,30 +229,52 @@ class prolog_KB(KBBase):
super().__init__(pseudo_label_list)
self.prolog = pyswip.Prolog()

def logic_forward(self):
pass
def logic_forward(self, pseudo_labels):
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
if result == 'true':
return True
elif result == 'false':
return False
return result
def _address_pred_res(self, pred_res, address_idx, multiple_predictions):
import re
address_pred_res = pred_res.copy()
if multiple_predictions:
address_pred_res = flatten(address_pred_res)
for idx in range(len(address_pred_res)):
if idx in address_idx:
address_pred_res[idx] = 'P' + str(idx)
if multiple_predictions:
address_pred_res = reform_idx(address_pred_res, pred_res)
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
def get_query_string(self, pred_res, key, address_idx, multiple_predictions):
query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions)
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
return query_string

def _find_candidate_GKB(self):
pass
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
# print(address_idx)
query_string = self.get_query_string(pred_res, key, address_idx)

query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions)
if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)

abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]

if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)

candidates.append(candidate)
return candidates

@@ -264,37 +286,12 @@ class add_prolog_KB(prolog_KB):
self.prolog.assertz("pseudo_label(%s)" % i)
self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")

def logic_forward(self, nums):
return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res']

def get_query_string(self, pred_res, key, address_idx):
query_string = "addition("
for idx, i in enumerate(pred_res):
tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ','
query_string += tmp
query_string += "%s)." % key
return query_string


class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list=[0, 1, '+', '=']):
super().__init__(pseudo_label_list)
self.prolog.consult('./datasets/hed/learn_add.pl')

def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0

def get_query_string(self, pred_res, key, address_idx):
flatten_pred_res = flatten(pred_res)
# add variables for prolog
for idx in range(len(flatten_pred_res)):
if idx in address_idx:
flatten_pred_res[idx] = 'X' + str(idx)
pred_res = reform_idx(flatten_pred_res, pred_res)

query_string = "abduce_consistent_insts(%s)." % pred_res
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")

def consist_rule(self, exs, rules):
rules = str(rules).replace("\'","")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0


+ 2
- 2
abl/framework_hed.py View File

@@ -150,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for m in mappings:
pred_res = mapping_res(original_pred_res, m)
max_abduce_num = 20
solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num)
all_address_flag = reform_idx(solution, pred_res)

consistent_idx_tmp = []
@@ -158,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
candidate = abducer.kb.address_by_idx([pred_res[idx]], None, address_idx, True)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0])


+ 3
- 0
examples/datasets/hed/learn_add.pl View File

@@ -32,6 +32,9 @@ abduce_consistent_insts(Exs):-
% (Experimental) Uncomment to use parallel abduction
% abduce_consistent_exs_concurrent(Exs), !.
logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false.
logic_forward(Exs) :- abduce_consistent_insts(Exs).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce Delta_C given pseudo-labels
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


Loading…
Cancel
Save