diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index e2a16bf..d52f74b 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -50,26 +50,23 @@ class AbducerBase(abc.ABC): cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) candidate = candidates[np.argmin(cost_list)] return candidate - + + def _get_zoopt_score(self, sol_x, pred_res, pred_res_prob, key): + address_idx = np.where(sol_x != 0)[0] + candidates = self.address_by_idx(pred_res, key, address_idx) + if len(candidates) > 0: + return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) + else: + return len(pred_res) + def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): if not self.multiple_predictions: - address_idx = np.where(sol.get_x() != 0)[0] - candidates = self.address_by_idx(pred_res, key, address_idx) - if len(candidates) > 0: - return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) - else: - return len(pred_res) + return self._get_address_score(sol.get_x(), pred_res, pred_res_prob, key) else: all_address_flag = reform_idx(sol.get_x(), pred_res) score = 0 - # TODO:这个循环里,和上面if not self.multiple_predictions部分逻辑完全一样吧,应该把上面封装一下,然后下面循环里调用封装方法即可 for idx in range(len(pred_res)): - address_idx = np.where(all_address_flag[idx] != 0)[0] - candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx) - if len(candidates) > 0: - score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) - else: - score += len(pred_res[idx]) + score += self._get_address_score(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) return score def _constrain_address_num(self, solution, max_address_num): @@ -112,10 +109,10 @@ class AbducerBase(abc.ABC): return self.kb.abduce_rules(pred_res) def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): - # if self.multiple_predictions: - return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) - # else: - # return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] + if self.multiple_predictions: + return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) + else: + return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) @@ -241,4 +238,4 @@ if __name__ == '__main__': print() abduced_rules = abd.abduce_rules(consist_exs) - print(abduced_rules) \ No newline at end of file + print(abduced_rules)