diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 15f989b..d33a7e6 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -52,13 +52,52 @@ class AbducerBase(abc.ABC): return len(pred_res) def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): - all_address_flag = reform_idx(sol.get_x(), pred_res) if nested_length(pred_res) == 1: - return self._zoopt_address_score_single(all_address_flag, pred_res, pred_res_prob, key) + return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key) else: + all_address_flag = reform_idx(sol.get_x(), pred_res) + lefted_idx = [i for i in range(len(pred_res))] + candidate_size = [] + while lefted_idx: + temp_idx = [] + temp_idx.append(lefted_idx.pop(0)) + max_candidate_idx = [] + found = False + for idx in range(-1, len(pred_res)): + if (not idx in temp_idx) and (idx >= 0): + temp_idx.append(idx) + + pred = [] + k = [] + address_flag = [] + for idx in temp_idx: + pred.append(pred_res[idx]) + k.append(key[idx]) + address_flag += list(all_address_flag[idx]) + address_idx = np.where(np.array(address_flag) != 0)[0] + candidate = self.address_by_idx(pred, k, address_idx) + if len(candidate) == 0: + if len(temp_idx) > 1: + temp_idx.pop() + else: + if len(temp_idx) > len(max_candidate_idx): + found = True + max_candidate_idx = temp_idx.copy() + removed = [i for i in lefted_idx if i in max_candidate_idx] + + if found: + candidate_size.append(len(removed) + 1) + lefted_idx = [i for i in lefted_idx if i not in max_candidate_idx] + + candidate_size.sort() score = 0 - for idx in range(nested_length(pred_res)): - score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]]) + import math + for i in range(0, len(candidate_size)): + score -= math.exp(-i) * candidate_size[i] + + # score = 0 + # for idx in range(nested_length(pred_res)): + # score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]]) return score def _constrain_address_num(self, solution, max_address_num):