From 01ac0b8dd1fce2e15c4af97a53c5892f7f7a4d4b Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 7 Dec 2022 13:36:00 +0800 Subject: [PATCH] Fix bugs for non-zoopt option --- abducer/kb.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 4bcf448..7ade7f9 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -36,7 +36,10 @@ class KBBase(ABC): def address(self, address_num, pred_res, key, multiple_predictions = False): new_candidates = [] - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + if not multiple_predictions: + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + else: + address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num)) for address_idx in address_idx_list: candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) @@ -44,10 +47,10 @@ class KBBase(ABC): return new_candidates def correct_result(self, pred_res, key): - if type(key) == int: + if type(key) != bool: return abs(self.logic_forward(pred_res) - key) <= 1e-3 else: - return self.logic_forward(pred_res) == key + return self.logic_forward(pred_res) def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): candidates = []