|
|
|
@@ -16,17 +16,14 @@ from zoopt import Dimension, Objective, Parameter, Opt |
|
|
|
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist |
|
|
|
|
|
|
|
class AbducerBase(abc.ABC): |
|
|
|
def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): |
|
|
|
def __init__(self, kb, dist_func='hamming', zoopt=False, multiple_predictions=False): |
|
|
|
self.kb = kb |
|
|
|
assert dist_func == 'hamming' or dist_func == 'confidence' |
|
|
|
self.dist_func = dist_func |
|
|
|
self.zoopt = zoopt |
|
|
|
self.multiple_predictions = multiple_predictions |
|
|
|
self.cache = cache |
|
|
|
|
|
|
|
if self.cache: |
|
|
|
self.cache_min_address_num = {} |
|
|
|
self.cache_candidates = {} |
|
|
|
if dist_func == 'confidence': |
|
|
|
self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) |
|
|
|
|
|
|
|
def _get_cost_list(self, pred_res, pred_res_prob, candidates): |
|
|
|
if self.dist_func == 'hamming': |
|
|
|
@@ -40,10 +37,7 @@ class AbducerBase(abc.ABC): |
|
|
|
if self.multiple_predictions: |
|
|
|
pred_res_prob = flatten(pred_res_prob) |
|
|
|
candidates = [flatten(c) for c in candidates] |
|
|
|
|
|
|
|
# TODO:这里应该在类创建时就提前存好,每次都重新计算也太费时了 |
|
|
|
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) |
|
|
|
candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] |
|
|
|
candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates] |
|
|
|
return confidence_dist(pred_res_prob, candidates) |
|
|
|
|
|
|
|
def _get_one_candidate(self, pred_res, pred_res_prob, candidates): |
|
|
|
@@ -54,46 +48,38 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
else: |
|
|
|
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) |
|
|
|
# TODO:这里很怪,按理argmin就行了 |
|
|
|
min_address_num = np.min(cost_list) |
|
|
|
idxs = np.where(cost_list == min_address_num)[0] |
|
|
|
# TODO:这里也很怪,取第一个就行了吧 |
|
|
|
candidate = [candidates[idx] for idx in idxs][0] |
|
|
|
candidate = candidates[np.argmin(cost_list)] |
|
|
|
return candidate |
|
|
|
|
|
|
|
# TODO:这里对zoopt的使用不太对。zoopt想要求解的是,修改哪几个符号的位置(表示为01串),能得到“最好”的反绎结果,理论上能比kb._address中的枚举法搜索次数更少。 |
|
|
|
# TODO:而这里“最好”的定义,和不用zoopt搜索时的定义一致,目前要么'hamming'、要么'confidence' |
|
|
|
# TODO: 因此,zoopt的作用,有点类似融合kb(得到若干反绎结果)和abducer(选一个反绎结果)的功能。 |
|
|
|
# TODO:下面两个函数的score,应该是得到candidates后,调用_get_one_candidate计算?(没细看) |
|
|
|
# for zoopt |
|
|
|
def _zoopt_score_multiple(self, pred_res, key, solution): |
|
|
|
all_address_flag = reform_idx(solution, pred_res) |
|
|
|
score = 0 |
|
|
|
# TODO:原版abl的score是这样计算的吗,我记得没有把若干预测结果拆开然后分别address的操作 |
|
|
|
for idx in range(len(pred_res)): |
|
|
|
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] |
|
|
|
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) |
|
|
|
if len(candidate) > 0: |
|
|
|
score += 1 |
|
|
|
return score |
|
|
|
|
|
|
|
def _zoopt_address_score(self, pred_res, key, sol): |
|
|
|
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): |
|
|
|
if not self.multiple_predictions: |
|
|
|
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] |
|
|
|
address_idx = np.where(sol.get_x() != 0)[0] |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx) |
|
|
|
return 1 if len(candidates) > 0 else 0 |
|
|
|
if len(candidates) > 0: |
|
|
|
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) |
|
|
|
else: |
|
|
|
return len(pred_res) |
|
|
|
else: |
|
|
|
return self._zoopt_score_multiple(pred_res, key, sol.get_x()) |
|
|
|
|
|
|
|
all_address_flag = reform_idx(sol.get_x(), pred_res) |
|
|
|
score = 0 |
|
|
|
for idx in range(len(pred_res)): |
|
|
|
address_idx = np.where(all_address_flag[idx] != 0)[0] |
|
|
|
candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx) |
|
|
|
if len(candidates) > 0: |
|
|
|
score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) |
|
|
|
else: |
|
|
|
score += len(pred_res) |
|
|
|
return -self._zoopt_score_multiple(pred_res, key, sol.get_x()) |
|
|
|
|
|
|
|
def _constrain_address_num(self, solution, max_address_num): |
|
|
|
x = solution.get_x() |
|
|
|
return max_address_num - x.sum() |
|
|
|
|
|
|
|
def zoopt_get_solution(self, pred_res, key, max_address_num): |
|
|
|
def zoopt_get_solution(self, pred_res, pred_res_prob, key, max_address_num): |
|
|
|
length = len(flatten(pred_res)) |
|
|
|
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) |
|
|
|
objective = Objective( |
|
|
|
lambda sol: -self._zoopt_address_score(pred_res, key, sol), |
|
|
|
lambda sol: self._zoopt_address_score(pred_res, pred_res_prob, key, sol), |
|
|
|
dim=dimension, |
|
|
|
constraint=lambda sol: self._constrain_address_num(sol, max_address_num), |
|
|
|
) |
|
|
|
@@ -101,31 +87,7 @@ class AbducerBase(abc.ABC): |
|
|
|
solution = Opt.min(objective, parameter).get_x() |
|
|
|
|
|
|
|
return solution |
|
|
|
|
|
|
|
# TODO:cache移到kb里吧,比如_abduce_by_search里,它存的是若干反绎结果,不涉及从若干反绎结果中选一个 |
|
|
|
# TODO:python也有自带的用装饰器实现的缓存方法,比如functools.lru_cache、cachetools等,后面稍微调研一下和手动缓存的优劣,看看用哪个好 |
|
|
|
def _get_cache(self, data, max_address_num, require_more_address): |
|
|
|
pred_res, pred_res_prob, key = data |
|
|
|
if self.multiple_predictions: |
|
|
|
pred_res = flatten(pred_res) |
|
|
|
key = tuple(key) |
|
|
|
if (tuple(pred_res), key) in self.cache_min_address_num: |
|
|
|
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) |
|
|
|
if (tuple(pred_res), key, address_num) in self.cache_candidates: |
|
|
|
candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] |
|
|
|
if self.zoopt: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
return self._get_one_candidate(pred_res, pred_res_prob, candidates) |
|
|
|
return None |
|
|
|
|
|
|
|
def _set_cache(self, pred_res, key, min_address_num, address_num, candidates): |
|
|
|
if self.multiple_predictions: |
|
|
|
pred_res = flatten(pred_res) |
|
|
|
key = tuple(key) |
|
|
|
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num |
|
|
|
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates |
|
|
|
|
|
|
|
|
|
|
|
def address_by_idx(self, pred_res, key, address_idx): |
|
|
|
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) |
|
|
|
|
|
|
|
@@ -134,27 +96,17 @@ class AbducerBase(abc.ABC): |
|
|
|
if max_address_num == -1: |
|
|
|
max_address_num = len(flatten(pred_res)) |
|
|
|
|
|
|
|
if self.cache: |
|
|
|
candidate = self._get_cache(data, max_address_num, require_more_address) |
|
|
|
if candidate is not None: |
|
|
|
return candidate |
|
|
|
|
|
|
|
if self.zoopt: |
|
|
|
solution = self.zoopt_get_solution(pred_res, key, max_address_num) |
|
|
|
address_idx = [idx for idx, i in enumerate(solution) if i != 0] |
|
|
|
solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num) |
|
|
|
address_idx = np.where(solution != 0)[0] |
|
|
|
candidates = self.address_by_idx(pred_res, key, address_idx) |
|
|
|
address_num = int(solution.sum()) |
|
|
|
min_address_num = address_num |
|
|
|
else: |
|
|
|
candidates, min_address_num, address_num = self.kb.abduce_candidates( |
|
|
|
candidates = self.kb.abduce_candidates( |
|
|
|
pred_res, key, max_address_num, require_more_address, self.multiple_predictions |
|
|
|
) |
|
|
|
|
|
|
|
candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) |
|
|
|
|
|
|
|
if self.cache: |
|
|
|
self._set_cache(pred_res, key, min_address_num, address_num, candidates) |
|
|
|
|
|
|
|
return candidate |
|
|
|
|
|
|
|
def abduce_rules(self, pred_res): |
|
|
|
@@ -283,9 +235,9 @@ if __name__ == '__main__': |
|
|
|
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) |
|
|
|
print() |
|
|
|
|
|
|
|
res = abd.abduce((consist_exs, None, [None] * len(consist_exs))) |
|
|
|
res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) |
|
|
|
print(res) |
|
|
|
res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs))) |
|
|
|
res = abd.abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) |
|
|
|
print(res) |
|
|
|
print() |
|
|
|
|
|
|
|
|