| @@ -50,26 +50,23 @@ class AbducerBase(abc.ABC): | |||
| cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) | |||
| candidate = candidates[np.argmin(cost_list)] | |||
| return candidate | |||
| def _get_zoopt_score(self, sol_x, pred_res, pred_res_prob, key): | |||
| address_idx = np.where(sol_x != 0)[0] | |||
| candidates = self.address_by_idx(pred_res, key, address_idx) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) | |||
| else: | |||
| return len(pred_res) | |||
| def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol): | |||
| if not self.multiple_predictions: | |||
| address_idx = np.where(sol.get_x() != 0)[0] | |||
| candidates = self.address_by_idx(pred_res, key, address_idx) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) | |||
| else: | |||
| return len(pred_res) | |||
| return self._get_address_score(sol.get_x(), pred_res, pred_res_prob, key) | |||
| else: | |||
| all_address_flag = reform_idx(sol.get_x(), pred_res) | |||
| score = 0 | |||
| # TODO:这个循环里,和上面if not self.multiple_predictions部分逻辑完全一样吧,应该把上面封装一下,然后下面循环里调用封装方法即可 | |||
| for idx in range(len(pred_res)): | |||
| address_idx = np.where(all_address_flag[idx] != 0)[0] | |||
| candidates = self.address_by_idx([pred_res[idx]], key[idx], address_idx) | |||
| if len(candidates) > 0: | |||
| score += np.min(self._get_cost_list(pred_res[idx], pred_res_prob[idx], candidates)) | |||
| else: | |||
| score += len(pred_res[idx]) | |||
| score += self._get_address_score(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key) | |||
| return score | |||
| def _constrain_address_num(self, solution, max_address_num): | |||
| @@ -112,10 +109,10 @@ class AbducerBase(abc.ABC): | |||
| return self.kb.abduce_rules(pred_res) | |||
| def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| # if self.multiple_predictions: | |||
| return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) | |||
| # else: | |||
| # return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||
| if self.multiple_predictions: | |||
| return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) | |||
| else: | |||
| return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||
| def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | |||
| @@ -241,4 +238,4 @@ if __name__ == '__main__': | |||
| print() | |||
| abduced_rules = abd.abduce_rules(consist_exs) | |||
| print(abduced_rules) | |||
| print(abduced_rules) | |||