diff --git a/abducer/kb.py b/abducer/kb.py index 4bcf448..7ade7f9 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -36,7 +36,10 @@ class KBBase(ABC): def address(self, address_num, pred_res, key, multiple_predictions = False): new_candidates = [] - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + if not multiple_predictions: + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + else: + address_idx_list = list(combinations(list(range(len(self.flatten(pred_res)))), address_num)) for address_idx in address_idx_list: candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) @@ -44,10 +47,10 @@ class KBBase(ABC): return new_candidates def correct_result(self, pred_res, key): - if type(key) == int: + if type(key) != bool: return abs(self.logic_forward(pred_res) - key) <= 1e-3 else: - return self.logic_forward(pred_res) == key + return self.logic_forward(pred_res) def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions = False): candidates = []