From 18fc1be9fa2b25da7690182f675757bbfc78345b Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 23 Nov 2022 13:40:47 +0800 Subject: [PATCH] Speed up get_abduce_candidates --- abducer/abducer_base.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 83826a6..63a3c73 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -110,11 +110,13 @@ class AbducerBase(abc.ABC): address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) for address_idx in address_idx_list: for c in all_address_candidate: - pred_res_array = np.array(pred_res) - if np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num: - pred_res_array[np.array(address_idx)] = c - if self.kb.logic_forward(pred_res_array) == key: - new_candidates.append(pred_res_array) + address_list = [pred_res[i] for i in address_idx] + if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0): + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + if self.kb.logic_forward(candidate) == key: + new_candidates.append(candidate) return new_candidates def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):