diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index bbf00fa..6de9987 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -11,16 +11,12 @@ #================================================================# import abc -# from kb import add_KB, hwf_KB -from abducer.kb import add_KB, hwf_KB +from kb import add_KB, hwf_KB +# from abducer.kb import add_KB, hwf_KB import numpy as np from itertools import product, combinations - - - - class AbducerBase(abc.ABC): def __init__(self, kb, dist_func = 'confidence', cache = True): self.kb = kb @@ -57,10 +53,15 @@ class AbducerBase(abc.ABC): return self.confidence_dist(pred_res_prob, candidates) def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): - cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) - min_address_num = np.min(cost_list) - idxs = np.where(cost_list == min_address_num)[0] - return [candidates[idx] for idx in idxs][0] + if(len(candidates) == 0): + return [] + elif(len(candidates) == 1): + return candidates[0] + else: + cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) + min_address_num = np.min(cost_list) + idxs = np.where(cost_list == min_address_num)[0] + return [candidates[idx] for idx in idxs][0] def abduce(self, data, max_address_num = -1, require_more_address = 0): pred_res, pred_res_prob, ans = data @@ -75,13 +76,16 @@ class AbducerBase(abc.ABC): candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidate - if(self.kb.base != {}): + if self.kb.GKB_flag: all_candidates = self.kb.get_candidates(ans, len(pred_res)) - cost_list = self.hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] + if(len(all_candidates) == 0): + return [] + else: + cost_list = self.hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] else: candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) @@ -111,8 +115,7 @@ class AbducerBase(abc.ABC): candidates = [] for address_num in range(len(pred_res) + 1): if(address_num > max_address_num): - print('No candidates found') - return None, None, None + return [], None, None if(address_num == 0): if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): @@ -146,24 +149,26 @@ class AbducerBase(abc.ABC): if __name__ == '__main__': - kb = add_KB() + kb = add_KB(GKB_flag = True) abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 1) + res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1) print(res) - res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 1, require_more_address = 1) + res = abd.abduce(([1, 1], None, 4), max_address_num = 1, require_more_address = 1) print(res) - res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce(([1, 1, 1], None, 5), max_address_num = 2, require_more_address = 1) + res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1) print(res) print() kb = hwf_KB() - abd = AbducerBase(kb) + abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) print(res) + res = abd.abduce((['5', '+', '2'], None, 3.09), max_address_num = 2, require_more_address = 0) + print(res) res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) print(res) res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3)