| @@ -40,7 +40,8 @@ class AbducerBase(abc.ABC): | |||||
| if self.multiple_predictions: | if self.multiple_predictions: | ||||
| pred_res_prob = flatten(pred_res_prob) | pred_res_prob = flatten(pred_res_prob) | ||||
| candidates = [flatten(c) for c in candidates] | candidates = [flatten(c) for c in candidates] | ||||
| # TODO:这里应该在类创建时就提前存好,每次都重新计算也太费时了 | |||||
| mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | ||||
| candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] | candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] | ||||
| return confidence_dist(pred_res_prob, candidates) | return confidence_dist(pred_res_prob, candidates) | ||||
| @@ -53,15 +54,22 @@ class AbducerBase(abc.ABC): | |||||
| else: | else: | ||||
| cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) | cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) | ||||
| # TODO:这里很怪,按理argmin就行了 | |||||
| min_address_num = np.min(cost_list) | min_address_num = np.min(cost_list) | ||||
| idxs = np.where(cost_list == min_address_num)[0] | idxs = np.where(cost_list == min_address_num)[0] | ||||
| # TODO:这里也很怪,取第一个就行了吧 | |||||
| candidate = [candidates[idx] for idx in idxs][0] | candidate = [candidates[idx] for idx in idxs][0] | ||||
| return candidate | return candidate | ||||
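| # TODO: zoopt is not being used quite correctly here. The problem zoopt should solve is: which symbol positions to revise (encoded as a 0/1 string) so as to obtain the "best" abduction result; in theory this needs fewer search steps than the enumeration in kb._address. | |||||
| # TODO: "Best" here is defined exactly as it is when zoopt is not used for the search: currently either 'hamming' or 'confidence'. | |||||
| # TODO: Consequently, zoopt's role is roughly a fusion of the kb (which produces several abduction results) and the abducer (which picks one of them). | |||||
| # TODO: The score in the two functions below should probably be computed by calling _get_one_candidate once the candidates are obtained? (Not examined closely.) | |||||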
| # for zoopt | # for zoopt | ||||
| def _zoopt_score_multiple(self, pred_res, key, solution): | def _zoopt_score_multiple(self, pred_res, key, solution): | ||||
| all_address_flag = reform_idx(solution, pred_res) | all_address_flag = reform_idx(solution, pred_res) | ||||
| score = 0 | score = 0 | ||||
| # TODO:原版abl的score是这样计算的吗,我记得没有把若干预测结果拆开然后分别address的操作 | |||||
| for idx in range(len(pred_res)): | for idx in range(len(pred_res)): | ||||
| address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | ||||
| candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) | candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) | ||||
| @@ -94,6 +102,8 @@ class AbducerBase(abc.ABC): | |||||
| return solution | return solution | ||||
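| # TODO: Move the cache into the kb, e.g. into _abduce_by_search: what it stores is the set of abduction results, and it is not involved in choosing one of them. | |||||
| # TODO: Python also offers decorator-based caching, e.g. the standard library's functools.lru_cache or the third-party cachetools package; investigate their trade-offs against this hand-written cache and decide which to use. | |||||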
| def _get_cache(self, data, max_address_num, require_more_address): | def _get_cache(self, data, max_address_num, require_more_address): | ||||
| pred_res, pred_res_prob, key = data | pred_res, pred_res_prob, key = data | ||||
| if self.multiple_predictions: | if self.multiple_predictions: | ||||