diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index d52f74b..06845a3 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -51,7 +51,7 @@ class AbducerBase(abc.ABC): candidate = candidates[np.argmin(cost_list)] return candidate - def _get_zoopt_score(self, sol_x, pred_res, pred_res_prob, key): + def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key): address_idx = np.where(sol_x != 0)[0] candidates = self.address_by_idx(pred_res, key, address_idx) if len(candidates) > 0: @@ -61,12 +61,12 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): if not self.multiple_predictions: - return self._get_address_score(sol.get_x(), pred_res, pred_res_prob, key) + return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key) else: all_address_flag = reform_idx(sol.get_x(), pred_res) score = 0 for idx in range(len(pred_res)): - score += self._get_address_score(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) + score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) return score def _constrain_address_num(self, solution, max_address_num):