diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 83826a6..63a3c73 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -110,11 +110,13 @@ class AbducerBase(abc.ABC): address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) for address_idx in address_idx_list: for c in all_address_candidate: - pred_res_array = np.array(pred_res) - if np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num: - pred_res_array[np.array(address_idx)] = c - if self.kb.logic_forward(pred_res_array) == key: - new_candidates.append(pred_res_array) + address_list = [pred_res[i] for i in address_idx] + if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0): + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + if self.kb.logic_forward(candidate) == key: + new_candidates.append(candidate) return new_candidates def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):