Browse Source

Update abducer_base.py and related files

pull/3/head
troyyyyy 2 years ago
parent
commit
19f29e4088
3 changed files with 78 additions and 93 deletions
  1. +30
    -78
      abl/abducer/abducer_base.py
  2. +47
    -14
      abl/abducer/kb.py
  3. +1
    -1
      abl/framework_hed.py

+ 30
- 78
abl/abducer/abducer_base.py View File

@@ -16,17 +16,14 @@ from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist

class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True):
def __init__(self, kb, dist_func='hamming', zoopt=False, multiple_predictions=False):
self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence'
self.dist_func = dist_func
self.zoopt = zoopt
self.multiple_predictions = multiple_predictions
self.cache = cache

if self.cache:
self.cache_min_address_num = {}
self.cache_candidates = {}
if dist_func == 'confidence':
self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))

def _get_cost_list(self, pred_res, pred_res_prob, candidates):
if self.dist_func == 'hamming':
@@ -40,10 +37,7 @@ class AbducerBase(abc.ABC):
if self.multiple_predictions:
pred_res_prob = flatten(pred_res_prob)
candidates = [flatten(c) for c in candidates]
# TODO: this mapping should be built once and stored when the class is created; recomputing it on every call is too slow
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
candidates = [list(map(lambda x: mapping[x], c)) for c in candidates]
candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates]
return confidence_dist(pred_res_prob, candidates)

def _get_one_candidate(self, pred_res, pred_res_prob, candidates):
@@ -54,46 +48,38 @@ class AbducerBase(abc.ABC):
else:
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
# TODO: this is odd; a plain argmin should be enough here
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
# TODO: this is odd too; simply taking the first one should be enough
candidate = [candidates[idx] for idx in idxs][0]
candidate = candidates[np.argmin(cost_list)]
return candidate

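# TODO: zoopt is not being used quite correctly here. What zoopt should solve is: which symbol positions to revise (encoded as a 0/1 string) in order to obtain the "best" abduction result; in theory this needs fewer search steps than the enumeration in kb._address.
# TODO: the definition of "best" here is the same as when searching without zoopt: currently either 'hamming' or 'confidence'.
# TODO: so zoopt effectively merges the roles of kb (produce several abduction results) and abducer (pick one abduction result).
# TODO: the score in the two functions below should probably be computed by calling _get_one_candidate once the candidates are obtained? (not examined closely)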
# for zoopt
def _zoopt_score_multiple(self, pred_res, key, solution):
all_address_flag = reform_idx(solution, pred_res)
score = 0
# TODO: is this how the original ABL computes the score? As I recall, it did not split the predictions apart and address each one separately
for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidate) > 0:
score += 1
return score

def _zoopt_address_score(self, pred_res, key, sol):
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
if not self.multiple_predictions:
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0]
address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
return 1 if len(candidates) > 0 else 0
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
else:
return self._zoopt_score_multiple(pred_res, key, sol.get_x())

all_address_flag = reform_idx(sol.get_x(), pred_res)
score = 0
for idx in range(len(pred_res)):
address_idx = np.where(all_address_flag[idx] != 0)[0]
candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidates) > 0:
score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates))
else:
score += len(pred_res)
return -self._zoopt_score_multiple(pred_res, key, sol.get_x())
def _constrain_address_num(self, solution, max_address_num):
x = solution.get_x()
return max_address_num - x.sum()

def zoopt_get_solution(self, pred_res, key, max_address_num):
def zoopt_get_solution(self, pred_res, pred_res_prob, key, max_address_num):
length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
objective = Objective(
lambda sol: -self._zoopt_address_score(pred_res, key, sol),
lambda sol: self._zoopt_address_score(pred_res, pred_res_prob, key, sol),
dim=dimension,
constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
)
@@ -101,31 +87,7 @@ class AbducerBase(abc.ABC):
solution = Opt.min(objective, parameter).get_x()

return solution

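# TODO: move the cache into kb, e.g. into _abduce_by_search; it stores the set of abduction results and is not involved in selecting one of them
# TODO: Python also offers ready-made decorator-based caching, e.g. functools.lru_cache, cachetools, etc.; investigate their pros and cons versus manual caching later and decide which to use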
def _get_cache(self, data, max_address_num, require_more_address):
pred_res, pred_res_prob, key = data
if self.multiple_predictions:
pred_res = flatten(pred_res)
key = tuple(key)
if (tuple(pred_res), key) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
if (tuple(pred_res), key, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
if self.zoopt:
return candidates[0]
else:
return self._get_one_candidate(pred_res, pred_res_prob, candidates)
return None

def _set_cache(self, pred_res, key, min_address_num, address_num, candidates):
if self.multiple_predictions:
pred_res = flatten(pred_res)
key = tuple(key)
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
def address_by_idx(self, pred_res, key, address_idx):
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)

@@ -134,27 +96,17 @@ class AbducerBase(abc.ABC):
if max_address_num == -1:
max_address_num = len(flatten(pred_res))

if self.cache:
candidate = self._get_cache(data, max_address_num, require_more_address)
if candidate is not None:
return candidate

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, key, max_address_num)
address_idx = [idx for idx, i in enumerate(solution) if i != 0]
solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num)
address_idx = np.where(solution != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
address_num = int(solution.sum())
min_address_num = address_num
else:
candidates, min_address_num, address_num = self.kb.abduce_candidates(
candidates = self.kb.abduce_candidates(
pred_res, key, max_address_num, require_more_address, self.multiple_predictions
)

candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)

if self.cache:
self._set_cache(pred_res, key, min_address_num, address_num, candidates)

return candidate

def abduce_rules(self, pred_res):
@@ -283,9 +235,9 @@ if __name__ == '__main__':
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print()

res = abd.abduce((consist_exs, None, [None] * len(consist_exs)))
res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs)))
print(res)
res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs)))
res = abd.abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs)))
print(res)
print()



+ 47
- 14
abl/abducer/kb.py View File

@@ -17,24 +17,30 @@ import numpy as np

from collections import defaultdict
from itertools import product, combinations
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal
from utils.utils import flatten, reform_idx, hamming_dist, check_equal

from multiprocessing import Pool

from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):#, abduce_cache=True):
self.pseudo_label_list = pseudo_label_list
self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err
# self.abduce_cache = abduce_cache

if GKB_flag:
self.base = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
# if abduce_cache:
# self.cache_min_address_num = {}
# self.cache_candidates = {}

# For parallel version of _get_GKB
def _get_XY_list(self, args):
@@ -92,21 +98,21 @@ class KBBase(ABC):
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}:
return [], 0, 0
return []

if not multiple_predictions:
if len(pred_res) not in self.len_list:
return [], 0, 0
return []
all_candidates = self._find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return [], 0, 0
return []
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
return candidates
else:
min_address_num = 0
@@ -115,10 +121,10 @@ class KBBase(ABC):
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return [], 0, 0
return []
all_candidates = self._find_candidate_GKB(p_res, k)
if len(all_candidates) == 0:
return [], 0, 0
return []
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
@@ -126,13 +132,31 @@ class KBBase(ABC):
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
assert len(multiple_all_candidates) == len(multiple_cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num
return candidates
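# TODO: Python also offers ready-made decorator-based caching, e.g. functools.lru_cache, cachetools, etc.; investigate their pros and cons versus manual caching later and decide which to use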
# def _get_abduce_cache(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# if multiple_predictions:
# pred_res = flatten(pred_res)
# key = tuple(key)
# if (tuple(pred_res), key) in self.cache_min_address_num:
# address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
# if (tuple(pred_res), key, address_num) in self.cache_candidates:
# candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
# return candidates
# return None

# def _set_abduce_cache(self, pred_res, key, min_address_num, address_num, candidates, multiple_predictions):
# if multiple_predictions:
# pred_res = flatten(pred_res)
# key = tuple(key)
# self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
# self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = []
@@ -166,7 +190,13 @@ class KBBase(ABC):
new_candidates += candidates
return new_candidates

# @lru_cache(maxsize=100)
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
# if self.abduce_cache:
# candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions)
# if candidates is not None:
# return candidates

candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
@@ -182,15 +212,18 @@ class KBBase(ABC):
break

if address_num >= max_address_num:
return [], 0, 0
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates
# if self.abduce_cache:
# self._set_abduce_cache(pred_res, key, min_address_num, address_num, candidates, multiple_predictions)

return candidates, min_address_num, address_num
return candidates

def _dict_len(self, dic):
if not self.GKB_flag:
@@ -346,4 +379,4 @@ class HWF_KB(RegKB):


if __name__ == "__main__":
pass
pass

+ 1
- 1
abl/framework_hed.py View File

@@ -150,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for m in mappings:
pred_res = mapping_res(original_pred_res, m)
max_abduce_num = 20
solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num)
solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num)
all_address_flag = reform_idx(solution, pred_res)

consistent_idx_tmp = []


Loading…
Cancel
Save