Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
7e4102c709
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 239 additions and 134 deletions
  1. +239
    -134
      abl/abducer/kb.py

+ 239
- 134
abl/abducer/kb.py View File

@@ -31,6 +31,15 @@ class KBBase(ABC):
@abstractmethod
def logic_forward(self):
pass
def _logic_forward(self, xs, multiple_predictions=False):
if not multiple_predictions:
return self.logic_forward(xs)
else:
res = []
for x in xs:
res.append(self.logic_forward(x))
return res

@abstractmethod
def abduce_candidates(self):
@@ -40,7 +49,7 @@ class KBBase(ABC):
def address_by_idx(self):
pass

def _address(self, address_num, pred_res, key, multiple_predictions=False):
def _address(self, address_num, pred_res, key, multiple_predictions):
new_candidates = []
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
@@ -52,12 +61,12 @@ class KBBase(ABC):
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key):
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
@@ -88,16 +97,14 @@ class ClsKB(KBBase):
self.GKB_flag = GKB_flag
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.max_err = 0

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
else:
self.all_address_candidate_dict = {}
for address_num in range(max(self.len_list) + 1):
self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num))


# For parallel version of _get_GKB
def _get_XY_list(self, args):
@@ -133,33 +140,57 @@ class ClsKB(KBBase):

def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag:
# TODO: 这里有可能是multiple_predictions吗
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {} or len(pred_res) not in self.len_list:
return []

all_candidates = self.base[len(pred_res)][key]
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}:
return [], 0, 0

if len(all_candidates) == 0:
candidates = []
min_address_num = 0
address_num = 0
if not multiple_predictions:
if len(pred_res) not in self.len_list:
return [], 0, 0
all_candidates = self.base[len(pred_res)][key]
if len(all_candidates) == 0:
return [], 0, 0
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return [], 0, 0
all_candidates = self.base[len(p_res)][k]
if len(all_candidates) == 0:
return [], 0, 0
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
min_address_num += np.min(cost_list)
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
assert len(multiple_all_candidates) == len(multiple_cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]

return candidates, min_address_num, address_num
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num

def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
abduce_c = self.all_address_candidate_dict[len(address_idx)]
abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))

if multiple_predictions:
save_pred_res = pred_res
@@ -173,7 +204,7 @@ class ClsKB(KBBase):
if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)

if self.logic_forward(candidate) == key:
if check_equal(self._logic_forward(candidate, multiple_predictions), key):
candidates.append(candidate)
return candidates

@@ -197,50 +228,13 @@ class add_KB(ClsKB):
def logic_forward(self, nums):
return sum(nums)

# TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO)
class HWF_KB(ClsKB):
def __init__(
self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7]
):
super().__init__(GKB_flag, pseudo_label_list, len_list)

def valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
return False
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
return False
return True

def logic_forward(self, formula):
if not self.valid_candidate(formula):
return np.inf
mapping = {
'1': '1',
'2': '2',
'3': '3',
'4': '4',
'5': '5',
'6': '6',
'7': '7',
'8': '8',
'9': '9',
'+': '+',
'-': '-',
'times': '*',
'div': '/',
}
formula = [mapping[f] for f in formula]
return round(eval(''.join(formula)), 2)


class prolog_KB(KBBase):
def __init__(self, pseudo_label_list):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.prolog = pyswip.Prolog()
self.max_err = 0

def logic_forward(self):
pass
@@ -295,11 +289,11 @@ class add_prolog_KB(prolog_KB):
class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list=[0, 1, '+', '=']):
super().__init__(pseudo_label_list)
self.prolog.consult('../examples/datasets/hed/learn_add.pl')
self.prolog.consult('./datasets/hed/learn_add.pl')

# corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0
return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0

def get_query_string_need_flatten(self, pred_res, key, address_idx):
# flatten
@@ -329,93 +323,204 @@ class HED_prolog_KB(prolog_KB):
rules.append(rule.value)
return rules

# def consist_rules(self, pred_res, rules):

# TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果)
# TODO:我理解的RegKB是这样的:
# TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果
# TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果
# TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y
# TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义
# TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err
class RegKB(KBBase):
def __init__(self, GKB_flag=False, X=None, Y=None):
def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3):
super().__init__()
tmp_dict = {}
for x, y in zip(X, Y):
tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))

self.base = {}
for l in tmp_dict.keys():
data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values())))
X = [x for y, x in data]
Y = [y for y, x in data]
self.base[l] = (X, Y)

def valid_candidate(self):
pass
self.GKB_flag = GKB_flag
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.max_err = max_err

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)

# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y != np.inf:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
X, Y = [], []
for length in self.len_list:
arg_list = []
for pre_x in self.pseudo_label_list:
post_x_it = product(self.pseudo_label_list, repeat=length - 1)
arg_list.append((pre_x, post_x_it))
with Pool(processes=len(arg_list)) as pool:
ret_list = pool.map(self._get_XY_list, arg_list)
for XY_list in ret_list:
if len(XY_list) == 0:
continue
part_X, part_Y = zip(*XY_list)
X.extend(part_X)
Y.extend(part_Y)
return X, Y

def logic_forward(self):
pass

def _abduce_by_GKB(self, key, length=None):
if key is None:
return self.get_all_candidates()
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)

length = self._length(length)
def _regression_find_candidate_GKB(self, pred_res, key):
potential_candidates = self.base[len(pred_res)]
key_list = sorted(potential_candidates)
key_idx = bisect.bisect_left(key_list, key)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}:
return [], 0, 0

min_err = 999999
if not multiple_predictions:
if len(pred_res) not in self.len_list:
return [], 0, 0
all_candidates = self._regression_find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return [], 0, 0
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
else:
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return [], 0, 0
all_candidates = self._regression_find_candidate_GKB(p_res, k)
if len(all_candidates) == 0:
return [], 0, 0
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
min_address_num += np.min(cost_list)
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
assert len(multiple_all_candidates) == len(multiple_cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
for l in length:
X, Y = self.base[l]

idx = bisect.bisect_left(Y, key)
begin = max(0, idx - 1)
end = min(idx + 2, len(X))

for idx in range(begin, end):
err = abs(Y[idx] - key)
if abs(err - min_err) < 1e-9:
candidates.extend(X[idx])
elif err < min_err:
candidates = copy.deepcopy(X[idx])
min_err = err
return candidates
abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx)))

def get_all_candidates(self):
return sum([sum(D[0], []) for D in self.base.values()], [])
if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)

for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]

if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)

if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
candidates.append(candidate)
return candidates
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
return sum([sum(len(x) for x in D[0]) for D in self.base.values()])
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())

class HWF_KB(RegKB):
def __init__(
self, GKB_flag=False,
pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],
len_list=[1, 3, 5, 7],
max_err=1e-3
):
super().__init__(GKB_flag, pseudo_label_list, len_list, max_err)

def valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
return False
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
return False
return True

def logic_forward(self, formula):
if not self.valid_candidate(formula):
return np.inf
mapping = {
'1': '1',
'2': '2',
'3': '3',
'4': '4',
'5': '5',
'6': '6',
'7': '7',
'8': '8',
'9': '9',
'+': '+',
'-': '-',
'times': '*',
'div': '/',
}
formula = [mapping[f] for f in formula]
return round(eval(''.join(formula)), 2)


import time

if __name__ == "__main__":
t1 = time.time()
kb = HWF_KB(True)
kb = add_KB(True)
t2 = time.time()
print(t2 - t1)

# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
# Y = [2, 1, 1, 2, 2]
# kb = ClsKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(2, 5)
# print(res)
# res = kb.get_candidates(2, 3)
# print(res)
# res = kb.get_candidates(None)
# print(res)
# print()

# X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
# Y = [2, 1, 1, 2, 1.5, 1.5]
# kb = RegKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(1.6)
# print(res)
# res = kb.get_candidates(1.6, length = 9)
# print(res)
# res = kb.get_candidates(None)
# print(res)

Loading…
Cancel
Save