Browse Source

Rearrange abduce_by_GKB and address_by_idx to Base

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
4478e4e3de
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 80 additions and 178 deletions
  1. +80
    -178
      abl/abducer/kb.py

+ 80
- 178
abl/abducer/kb.py View File

@@ -10,20 +10,6 @@
# #
# ================================================================# # ================================================================#


from abc import ABC, abstractmethod
import bisect
import copy
import numpy as np

from collections import defaultdict
from itertools import product, combinations
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal

from multiprocessing import Pool

import pyswip


class KBBase(ABC): class KBBase(ABC):
def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0):
self.pseudo_label_list = pseudo_label_list self.pseudo_label_list = pseudo_label_list
@@ -79,78 +65,16 @@ class KBBase(ABC):
res.append(self.logic_forward(x)) res.append(self.logic_forward(x))
return res return res


@abstractmethod
def abduce_candidates(self):
pass
@abstractmethod
def address_by_idx(self):
pass

def _address(self, address_num, pred_res, key, multiple_predictions):
new_candidates = []
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
else:
address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

if len(candidates) > 0:
min_address_num = address_num
break

if address_num >= max_address_num:
return [], 0, 0

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

return candidates, min_address_num, address_num

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())


class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def logic_forward(self):
pass

def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag: if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else: else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

@abstractmethod
def _find_candidate_GKB(self):
pass
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}: if self.base == {}:
return [], 0, 0 return [], 0, 0
@@ -158,7 +82,7 @@ class ClsKB(KBBase):
if not multiple_predictions: if not multiple_predictions:
if len(pred_res) not in self.len_list: if len(pred_res) not in self.len_list:
return [], 0, 0 return [], 0, 0
all_candidates = self.base[len(pred_res)][key]
all_candidates = self._find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0: if len(all_candidates) == 0:
return [], 0, 0 return [], 0, 0
else: else:
@@ -168,7 +92,7 @@ class ClsKB(KBBase):
idxs = np.where(cost_list <= address_num)[0] idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs] candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num return candidates, min_address_num, address_num
else: else:
min_address_num = 0 min_address_num = 0
all_candidates_save = [] all_candidates_save = []
@@ -177,7 +101,7 @@ class ClsKB(KBBase):
for p_res, k in zip(pred_res, key): for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list: if len(p_res) not in self.len_list:
return [], 0, 0 return [], 0, 0
all_candidates = self.base[len(p_res)][k]
all_candidates = self._regression_find_candidate_GKB(p_res, k)
if len(all_candidates) == 0: if len(all_candidates) == 0:
return [], 0, 0 return [], 0, 0
else: else:
@@ -194,7 +118,7 @@ class ClsKB(KBBase):
idxs = np.where(multiple_cost_list <= address_num)[0] idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num return candidates, min_address_num, address_num
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = [] candidates = []
abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))
@@ -211,10 +135,71 @@ class ClsKB(KBBase):
if multiple_predictions: if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res) candidate = reform_idx(candidate, save_pred_res)


if check_equal(self._logic_forward(candidate, multiple_predictions), key):
if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
candidates.append(candidate) candidates.append(candidate)
return candidates return candidates


def _address(self, address_num, pred_res, key, multiple_predictions):
new_candidates = []
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
else:
address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

if len(candidates) > 0:
min_address_num = address_num
break

if address_num >= max_address_num:
return [], 0, 0

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

return candidates, min_address_num, address_num

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())


class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)

def logic_forward(self):
pass

def _find_candidate_GKB(self, pred_res, key):
return self.base[len(pred_res)][key]



class add_KB(ClsKB): class add_KB(ClsKB):
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
@@ -232,16 +217,13 @@ class prolog_KB(KBBase):
def logic_forward(self): def logic_forward(self):
pass pass


def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
def _find_candidate_GKB(self):
pass
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = [] candidates = []
# print(address_idx) # print(address_idx)
if not multiple_predictions:
query_string = self.get_query_string(pred_res, key, address_idx)
else:
query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
query_string = self.get_query_string(pred_res, key, address_idx)


if multiple_predictions: if multiple_predictions:
save_pred_res = pred_res save_pred_res = pred_res
@@ -284,21 +266,18 @@ class HED_prolog_KB(prolog_KB):
super().__init__(pseudo_label_list) super().__init__(pseudo_label_list)
self.prolog.consult('./datasets/hed/learn_add.pl') self.prolog.consult('./datasets/hed/learn_add.pl')


# corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
def logic_forward(self, exs): def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0


def get_query_string_need_flatten(self, pred_res, key, address_idx):
# flatten
def get_query_string(self, pred_res, key, address_idx):
flatten_pred_res = flatten(pred_res) flatten_pred_res = flatten(pred_res)
# add variables for prolog # add variables for prolog
for idx in range(len(flatten_pred_res)): for idx in range(len(flatten_pred_res)):
if idx in address_idx: if idx in address_idx:
flatten_pred_res[idx] = 'X' + str(idx) flatten_pred_res[idx] = 'X' + str(idx)
# unflatten
new_pred_res = reform_idx(flatten_pred_res, pred_res)
pred_res = reform_idx(flatten_pred_res, pred_res)


query_string = "abduce_consistent_insts(%s)." % new_pred_res
query_string = "abduce_consistent_insts(%s)." % pred_res
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")


def consist_rule(self, exs, rules): def consist_rule(self, exs, rules):
@@ -324,13 +303,7 @@ class RegKB(KBBase):
def logic_forward(self): def logic_forward(self):
pass pass


def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

def _regression_find_candidate_GKB(self, pred_res, key):
def _find_candidate_GKB(self, pred_res, key):
potential_candidates = self.base[len(pred_res)] potential_candidates = self.base[len(pred_res)]
key_list = sorted(potential_candidates) key_list = sorted(potential_candidates)
key_idx = bisect.bisect_left(key_list, key) key_idx = bisect.bisect_left(key_list, key)
@@ -351,70 +324,6 @@ class RegKB(KBBase):
break break
return all_candidates return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}:
return [], 0, 0

if not multiple_predictions:
if len(pred_res) not in self.len_list:
return [], 0, 0
all_candidates = self._regression_find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return [], 0, 0
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
else:
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return [], 0, 0
all_candidates = self._regression_find_candidate_GKB(p_res, k)
if len(all_candidates) == 0:
return [], 0, 0
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
min_address_num += np.min(cost_list)
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
assert len(multiple_all_candidates) == len(multiple_cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))

if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)

for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]

if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)

if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
candidates.append(candidate)
return candidates


class HWF_KB(RegKB): class HWF_KB(RegKB):
def __init__( def __init__(
@@ -456,13 +365,6 @@ class HWF_KB(RegKB):
formula = [mapping[f] for f in formula] formula = [mapping[f] for f in formula]
return round(eval(''.join(formula)), 2) return round(eval(''.join(formula)), 2)



import time

if __name__ == "__main__": if __name__ == "__main__":
t1 = time.time()
kb = add_KB(GKB_flag=True)
t2 = time.time()
print(t2 - t1)

pass

Loading…
Cancel
Save