| @@ -11,16 +11,12 @@ | |||||
| #================================================================# | #================================================================# | ||||
| import abc | import abc | ||||
| # from kb import add_KB, hwf_KB | |||||
| from abducer.kb import add_KB, hwf_KB | |||||
| from kb import add_KB, hwf_KB | |||||
| # from abducer.kb import add_KB, hwf_KB | |||||
| import numpy as np | import numpy as np | ||||
| from itertools import product, combinations | from itertools import product, combinations | ||||
| class AbducerBase(abc.ABC): | class AbducerBase(abc.ABC): | ||||
| def __init__(self, kb, dist_func = 'confidence', cache = True): | def __init__(self, kb, dist_func = 'confidence', cache = True): | ||||
| self.kb = kb | self.kb = kb | ||||
| @@ -57,10 +53,15 @@ class AbducerBase(abc.ABC): | |||||
| return self.confidence_dist(pred_res_prob, candidates) | return self.confidence_dist(pred_res_prob, candidates) | ||||
| def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): | def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): | ||||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| idxs = np.where(cost_list == min_address_num)[0] | |||||
| return [candidates[idx] for idx in idxs][0] | |||||
| if(len(candidates) == 0): | |||||
| return [] | |||||
| elif(len(candidates) == 1): | |||||
| return candidates[0] | |||||
| else: | |||||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| idxs = np.where(cost_list == min_address_num)[0] | |||||
| return [candidates[idx] for idx in idxs][0] | |||||
| def abduce(self, data, max_address_num = -1, require_more_address = 0): | def abduce(self, data, max_address_num = -1, require_more_address = 0): | ||||
| pred_res, pred_res_prob, ans = data | pred_res, pred_res_prob, ans = data | ||||
| @@ -75,13 +76,16 @@ class AbducerBase(abc.ABC): | |||||
| candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | ||||
| return candidate | return candidate | ||||
| if(self.kb.base != {}): | |||||
| if self.kb.GKB_flag: | |||||
| all_candidates = self.kb.get_candidates(ans, len(pred_res)) | all_candidates = self.kb.get_candidates(ans, len(pred_res)) | ||||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||||
| idxs = np.where(cost_list <= address_num)[0] | |||||
| candidates = [all_candidates[idx] for idx in idxs] | |||||
| if(len(all_candidates) == 0): | |||||
| return [] | |||||
| else: | |||||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||||
| min_address_num = np.min(cost_list) | |||||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||||
| idxs = np.where(cost_list <= address_num)[0] | |||||
| candidates = [all_candidates[idx] for idx in idxs] | |||||
| else: | else: | ||||
| candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | ||||
| @@ -111,8 +115,7 @@ class AbducerBase(abc.ABC): | |||||
| candidates = [] | candidates = [] | ||||
| for address_num in range(len(pred_res) + 1): | for address_num in range(len(pred_res) + 1): | ||||
| if(address_num > max_address_num): | if(address_num > max_address_num): | ||||
| print('No candidates found') | |||||
| return None, None, None | |||||
| return [], None, None | |||||
| if(address_num == 0): | if(address_num == 0): | ||||
| if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): | if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3): | ||||
| @@ -146,24 +149,26 @@ class AbducerBase(abc.ABC): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| kb = add_KB() | |||||
| kb = add_KB(GKB_flag = True) | |||||
| abd = AbducerBase(kb, 'hamming') | abd = AbducerBase(kb, 'hamming') | ||||
| res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0) | |||||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 1) | |||||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 1, require_more_address = 1) | |||||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 1, require_more_address = 1) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0) | |||||
| res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce(([1, 1, 1], None, 5), max_address_num = 2, require_more_address = 1) | |||||
| res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1) | |||||
| print(res) | print(res) | ||||
| print() | print() | ||||
| kb = hwf_KB() | kb = hwf_KB() | ||||
| abd = AbducerBase(kb) | |||||
| abd = AbducerBase(kb, 'hamming') | |||||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) | res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '2'], None, 3.09), max_address_num = 2, require_more_address = 0) | |||||
| print(res) | |||||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) | res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3) | res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3) | ||||