diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index c5c21fd..a70db14 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -16,17 +16,14 @@ from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): + def __init__(self, kb, dist_func='hamming', zoopt=False, multiple_predictions=False): self.kb = kb assert dist_func == 'hamming' or dist_func == 'confidence' self.dist_func = dist_func self.zoopt = zoopt self.multiple_predictions = multiple_predictions - self.cache = cache - - if self.cache: - self.cache_min_address_num = {} - self.cache_candidates = {} + if dist_func == 'confidence': + self.mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) def _get_cost_list(self, pred_res, pred_res_prob, candidates): if self.dist_func == 'hamming': @@ -40,10 +37,7 @@ class AbducerBase(abc.ABC): if self.multiple_predictions: pred_res_prob = flatten(pred_res_prob) candidates = [flatten(c) for c in candidates] - - # TODO:这里应该在类创建时就提前存好,每次都重新计算也太费时了 - mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) - candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] + candidates = [list(map(lambda x: self.mapping[x], c)) for c in candidates] return confidence_dist(pred_res_prob, candidates) def _get_one_candidate(self, pred_res, pred_res_prob, candidates): @@ -54,46 +48,38 @@ class AbducerBase(abc.ABC): else: cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) - # TODO:这里很怪,按理argmin就行了 - min_address_num = np.min(cost_list) - idxs = np.where(cost_list == min_address_num)[0] - # TODO:这里也很怪,取第一个就行了吧 - candidate = [candidates[idx] for idx in idxs][0] + candidate = candidates[np.argmin(cost_list)] return candidate - # TODO:这里对zoopt的使用不太对。zoopt想要求解的是,修改哪几个符号的位置(表示为01串),能得到“最好”的反绎结果,理论上能比kb._address中的枚举法搜索次数更少。 - # TODO:而这里“最好”的定义,和不用zoopt搜索时的定义一致,目前要么'hamming'、要么'confidence' - # TODO: 因此,zoopt的作用,有点类似融合kb(得到若干反绎结果)和abducer(选一个反绎结果)的功能。 - # TODO:下面两个函数的score,应该是得到candidates后,调用_get_one_candidate计算?(没细看) - # for zoopt - def _zoopt_score_multiple(self, pred_res, key, solution): - all_address_flag = reform_idx(solution, pred_res) - score = 0 - # TODO:原版abl的score是这样计算的吗,我记得没有把若干预测结果拆开然后分别address的操作 - for idx in range(len(pred_res)): - address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) - if len(candidate) > 0: - score += 1 - return score - - def _zoopt_address_score(self, pred_res, key, sol): + def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): if not self.multiple_predictions: - address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] + address_idx = np.where(sol.get_x() != 0)[0] candidates = self.address_by_idx(pred_res, key, address_idx) - return 1 if len(candidates) > 0 else 0 + if len(candidates) > 0: + return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) + else: + return len(pred_res) else: - return self._zoopt_score_multiple(pred_res, key, sol.get_x()) - + all_address_flag = reform_idx(sol.get_x(), pred_res) + score = 0 + for idx in range(len(pred_res)): + address_idx = np.where(all_address_flag[idx] != 0)[0] + candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx) + if len(candidates) > 0: + score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) + else: + score += len(pred_res) + return -self._zoopt_score_multiple(pred_res, key, sol.get_x()) + def _constrain_address_num(self, solution, max_address_num): x = solution.get_x() return max_address_num - x.sum() - def zoopt_get_solution(self, pred_res, key, max_address_num): + def zoopt_get_solution(self, pred_res, pred_res_prob, key, max_address_num): length = len(flatten(pred_res)) dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) objective = Objective( - lambda sol: -self._zoopt_address_score(pred_res, key, sol), + lambda sol: self._zoopt_address_score(pred_res, pred_res_prob, key, sol), dim=dimension, constraint=lambda sol: self._constrain_address_num(sol, max_address_num), ) @@ -101,31 +87,7 @@ class AbducerBase(abc.ABC): solution = Opt.min(objective, parameter).get_x() return solution - - # TODO:cache移到kb里吧,比如_abduce_by_search里,它存的是若干反绎结果,不涉及从若干反绎结果中选一个 - # TODO:python也有自带的用装饰器实现的缓存方法,比如functools.lru_cache、cachetools等,后面稍微调研一下和手动缓存的优劣,看看用哪个好 - def _get_cache(self, data, max_address_num, require_more_address): - pred_res, pred_res_prob, key = data - if self.multiple_predictions: - pred_res = flatten(pred_res) - key = tuple(key) - if (tuple(pred_res), key) in self.cache_min_address_num: - address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) - if (tuple(pred_res), key, address_num) in self.cache_candidates: - candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] - if self.zoopt: - return candidates[0] - else: - return self._get_one_candidate(pred_res, pred_res_prob, candidates) - return None - - def _set_cache(self, pred_res, key, min_address_num, address_num, candidates): - if self.multiple_predictions: - pred_res = flatten(pred_res) - key = tuple(key) - self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num - self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates - + def address_by_idx(self, pred_res, key, address_idx): return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) @@ -134,27 +96,17 @@ class AbducerBase(abc.ABC): if max_address_num == -1: max_address_num = len(flatten(pred_res)) - if self.cache: - candidate = self._get_cache(data, max_address_num, require_more_address) - if candidate is not None: - return candidate - if self.zoopt: - solution = self.zoopt_get_solution(pred_res, key, max_address_num) - address_idx = [idx for idx, i in enumerate(solution) if i != 0] + solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num) + address_idx = np.where(solution != 0)[0] candidates = self.address_by_idx(pred_res, key, address_idx) - address_num = int(solution.sum()) - min_address_num = address_num else: - candidates, min_address_num, address_num = self.kb.abduce_candidates( + candidates = self.kb.abduce_candidates( pred_res, key, max_address_num, require_more_address, self.multiple_predictions ) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) - if self.cache: - self._set_cache(pred_res, key, min_address_num, address_num, candidates) - return candidate def abduce_rules(self, pred_res): @@ -283,9 +235,9 @@ if __name__ == '__main__': print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print() - res = abd.abduce((consist_exs, None, [None] * len(consist_exs))) + res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs))) print(res) - res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs))) + res = abd.abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs))) print(res) print() diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index d656e41..2a29d12 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -17,24 +17,30 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal +from utils.utils import flatten, reform_idx, hamming_dist, check_equal from multiprocessing import Pool +from functools import lru_cache import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): + def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):#, abduce_cache=True): self.pseudo_label_list = pseudo_label_list self.len_list = len_list self.GKB_flag = GKB_flag self.max_err = max_err + # self.abduce_cache = abduce_cache if GKB_flag: self.base = {} X, Y = self._get_GKB() for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(x) + + # if abduce_cache: + # self.cache_min_address_num = {} + # self.cache_candidates = {} # For parallel version of _get_GKB def _get_XY_list(self, args): @@ -92,21 +98,21 @@ class KBBase(ABC): def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): if self.base == {}: - return [], 0, 0 + return [] if not multiple_predictions: if len(pred_res) not in self.len_list: - return [], 0, 0 + return [] all_candidates = self._find_candidate_GKB(pred_res, key) if len(all_candidates) == 0: - return [], 0, 0 + return [] else: cost_list = hamming_dist(pred_res, all_candidates) min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) idxs = np.where(cost_list <= address_num)[0] candidates = [all_candidates[idx] for idx in idxs] - return candidates, min_address_num, address_num + return candidates else: min_address_num = 0 @@ -115,10 +121,10 @@ class KBBase(ABC): for p_res, k in zip(pred_res, key): if len(p_res) not in self.len_list: - return [], 0, 0 + return [] all_candidates = self._find_candidate_GKB(p_res, k) if len(all_candidates) == 0: - return [], 0, 0 + return [] else: all_candidates_save.append(all_candidates) cost_list = hamming_dist(p_res, all_candidates) @@ -126,13 +132,31 @@ class KBBase(ABC): cost_list_save.append(cost_list) multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] - assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) - assert len(multiple_all_candidates) == len(multiple_cost_list) address_num = min(max_address_num, min_address_num + require_more_address) idxs = np.where(multiple_cost_list <= address_num)[0] candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] - return candidates, min_address_num, address_num + return candidates + + # TODO:python也有自带的用装饰器实现的缓存方法,比如functools.lru_cache、cachetools等,后面稍微调研一下和手动缓存的优劣,看看用哪个好 + # def _get_abduce_cache(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + # if multiple_predictions: + # pred_res = flatten(pred_res) + # key = tuple(key) + # if (tuple(pred_res), key) in self.cache_min_address_num: + # address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) + # if (tuple(pred_res), key, address_num) in self.cache_candidates: + # candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] + # return candidates + # return None + + # def _set_abduce_cache(self, pred_res, key, min_address_num, address_num, candidates, multiple_predictions): + # if multiple_predictions: + # pred_res = flatten(pred_res) + # key = tuple(key) + # self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num + # self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] @@ -166,7 +190,13 @@ class KBBase(ABC): new_candidates += candidates return new_candidates + # @lru_cache(maxsize=100) def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + # if self.abduce_cache: + # candidates = self._get_abduce_cache(pred_res, key, max_address_num, require_more_address, multiple_predictions) + # if candidates is not None: + # return candidates + candidates = [] for address_num in range(len(flatten(pred_res)) + 1): @@ -182,15 +212,18 @@ class KBBase(ABC): break if address_num >= max_address_num: - return [], 0, 0 + return [] for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): if address_num > max_address_num: return candidates, min_address_num, address_num - 1 new_candidates = self._address(address_num, pred_res, key, multiple_predictions) candidates += new_candidates + + # if self.abduce_cache: + # self._set_abduce_cache(pred_res, key, min_address_num, address_num, candidates, multiple_predictions) - return candidates, min_address_num, address_num + return candidates def _dict_len(self, dic): if not self.GKB_flag: @@ -346,4 +379,4 @@ class HWF_KB(RegKB): if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/abl/framework_hed.py b/abl/framework_hed.py index 3f09ff6..d339909 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -150,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for m in mappings: pred_res = mapping_res(original_pred_res, m) max_abduce_num = 20 - solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num) + solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) all_address_flag = reform_idx(solution, pred_res) consistent_idx_tmp = []