Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
d4ef2ca8d9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 25 additions and 47 deletions
  1. +25
    -47
      abducer/kb.py

+ 25
- 47
abducer/kb.py View File

@@ -15,8 +15,12 @@ import bisect
import copy
import numpy as np

import sys
sys.path.append("..")

from collections import defaultdict
from itertools import product, combinations
from utils.utils import _flatten, _reform_ids, _hamming_dist

import pyswip

@@ -39,25 +43,21 @@ class KBBase(ABC):
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
else:
address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num))
address_idx_list = list(combinations(list(range(len(_flatten(pred_res)))), address_num))
for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates
def correct_result(self, pred_res, key):
if type(key) != bool:
return abs(self.logic_forward(pred_res) - key) <= 1e-3
else:
return self.logic_forward(pred_res)
def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False):
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if self.correct_result(pred_res, key):
if abs(self.logic_forward(pred_res) - key) <= 1e-3:
candidates.append(pred_res)
else:
new_candidates = self.address(address_num, pred_res, key, multiple_predictions)
@@ -79,23 +79,7 @@ class KBBase(ABC):
return candidates, min_address_num, address_num
# for multiple predictions, modify from `learn_add.py`
def flatten(self, l):
return [item for sublist in l for item in sublist]
# for multiple predictions, modify from `learn_add.py`
def reform_ids(self, flatten_pred_res, save_pred_res):
re = []
i = 0
for e in save_pred_res:
j = 0
ids = []
while j < len(e):
ids.append(flatten_pred_res[i + j])
j += 1
re.append(ids)
i = i + j
return re

def __len__(self):
pass
@@ -110,7 +94,7 @@ class ClsKB(KBBase):
if GKB_flag:
self.base = {}
X, Y = self.get_GKB(self.pseudo_label_list, self.len_list)
X, Y = self._get_GKB(self.pseudo_label_list, self.len_list)
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
else:
@@ -118,7 +102,7 @@ class ClsKB(KBBase):
for address_num in range(max(self.len_list) + 1):
self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat = address_num))
def get_GKB(self, pseudo_label_list, len_list):
def _get_GKB(self, pseudo_label_list, len_list):
all_X = []
for len in len_list:
all_X += list(product(pseudo_label_list, repeat = len))
@@ -142,12 +126,6 @@ class ClsKB(KBBase):
return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions)



def hamming_dist(self, A, B):
B = np.array(B)
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)

def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {} or len(pred_res) not in self.len_list:
return []
@@ -159,7 +137,7 @@ class ClsKB(KBBase):
min_address_num = 0
address_num = 0
else:
cost_list = self.hamming_dist(pred_res, all_candidates)
cost_list = _hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
@@ -174,7 +152,7 @@ class ClsKB(KBBase):
if multiple_predictions:
save_pred_res = pred_res
pred_res = self.flatten(pred_res)
pred_res = _flatten(pred_res)
for c in abduce_c:
candidate = pred_res.copy()
@@ -182,7 +160,7 @@ class ClsKB(KBBase):
candidate[idx] = c[i]
if multiple_predictions:
candidate = self.reform_ids(candidate, save_pred_res)
candidate = _reform_ids(candidate, save_pred_res)
if self.logic_forward(candidate) == key:
candidates.append(candidate)
@@ -252,15 +230,15 @@ class prolog_KB(KBBase):
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions = False):
candidates = []
# print(address_idx)
if not multiple_predictions:
query_string = self.get_query_string(pred_res, key, address_idx)
else:
query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
query_string = self.get_query_string_need__flatten(pred_res, key, address_idx)
if multiple_predictions:
save_pred_res = pred_res
pred_res = self.flatten(pred_res)
pred_res = _flatten(pred_res)

abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
for c in abduce_c:
@@ -269,7 +247,7 @@ class prolog_KB(KBBase):
candidate[idx] = c[i]
if multiple_predictions:
candidate = self.reform_ids(candidate, save_pred_res)
candidate = _reform_ids(candidate, save_pred_res)
candidates.append(candidate)
return candidates
@@ -297,22 +275,22 @@ class add_prolog_KB(prolog_KB):
class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list = [0, 1, '+', '=']):
super().__init__(pseudo_label_list)
self.prolog.consult('../datasets/hed/learn_add.pl')
self.prolog.consult('./datasets/hed/learn_add.pl')
# corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0

def get_query_string_need_flatten(self, pred_res, key, address_idx):
# flatten
flatten_pred_res = self.flatten(pred_res)
def get_query_string_need__flatten(self, pred_res, key, address_idx):
# _flatten
_flatten_pred_res = _flatten(pred_res)
# add variables for prolog
for idx in range(len(flatten_pred_res)):
for idx in range(len(_flatten_pred_res)):
if idx in address_idx:
flatten_pred_res[idx] = 'X' + str(idx)
# unflatten
new_pred_res = self.reform_ids(flatten_pred_res, pred_res)
_flatten_pred_res[idx] = 'X' + str(idx)
# un_flatten
new_pred_res = _reform_ids(_flatten_pred_res, pred_res)
query_string = "abduce_consistent_insts(%s)." % new_pred_res
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")


Loading…
Cancel
Save