|
|
|
@@ -40,7 +40,8 @@ class AbducerBase(abc.ABC): |
|
|
|
if self.multiple_predictions: |
|
|
|
pred_res_prob = flatten(pred_res_prob) |
|
|
|
candidates = [flatten(c) for c in candidates] |
|
|
|
|
|
|
|
|
|
|
|
# TODO:这里应该在类创建时就提前存好,每次都重新计算也太费时了 |
|
|
|
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) |
|
|
|
candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] |
|
|
|
return confidence_dist(pred_res_prob, candidates) |
|
|
|
@@ -53,15 +54,22 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
else: |
|
|
|
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) |
|
|
|
# TODO:这里很怪,按理argmin就行了 |
|
|
|
min_address_num = np.min(cost_list) |
|
|
|
idxs = np.where(cost_list == min_address_num)[0] |
|
|
|
# TODO:这里也很怪,取第一个就行了吧 |
|
|
|
candidate = [candidates[idx] for idx in idxs][0] |
|
|
|
return candidate |
|
|
|
|
|
|
|
# TODO:这里对zoopt的使用不太对。zoopt想要求解的是,修改哪几个符号的位置(表示为01串),能得到“最好”的反绎结果,理论上能比kb._address中的枚举法搜索次数更少。 |
|
|
|
# TODO:而这里“最好”的定义,和不用zoopt搜索时的定义一致,目前要么'hamming'、要么'confidence' |
|
|
|
# TODO: 因此,zoopt的作用,有点类似融合kb(得到若干反绎结果)和abducer(选一个反绎结果)的功能。 |
|
|
|
# TODO:下面两个函数的score,应该是得到candidates后,调用_get_one_candidate计算?(没细看) |
|
|
|
# for zoopt |
|
|
|
def _zoopt_score_multiple(self, pred_res, key, solution): |
|
|
|
all_address_flag = reform_idx(solution, pred_res) |
|
|
|
score = 0 |
|
|
|
# TODO:原版abl的score是这样计算的吗,我记得没有把若干预测结果拆开然后分别address的操作 |
|
|
|
for idx in range(len(pred_res)): |
|
|
|
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] |
|
|
|
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) |
|
|
|
@@ -94,6 +102,8 @@ class AbducerBase(abc.ABC): |
|
|
|
|
|
|
|
return solution |
|
|
|
|
|
|
|
# TODO:cache移到kb里吧,比如_abduce_by_search里,它存的是若干反绎结果,不涉及从若干反绎结果中选一个 |
|
|
|
# TODO:python也有自带的用装饰器实现的缓存方法,比如functools.lru_cache、cachetools等,后面稍微调研一下和手动缓存的优劣,看看用哪个好 |
|
|
|
def _get_cache(self, data, max_address_num, require_more_address): |
|
|
|
pred_res, pred_res_prob, key = data |
|
|
|
if self.multiple_predictions: |
|
|
|
|