| @@ -38,34 +38,29 @@ def confidence_dist(A, B): | |||||
| class AbducerBase(abc.ABC): | class AbducerBase(abc.ABC): | ||||
| def __init__(self, kb, dist_func = 'hamming', pred_res_parse = None, cache = True): | |||||
| def __init__(self, kb, dist_func = 'confidence', cache = True): | |||||
| self.kb = kb | self.kb = kb | ||||
| if(dist_func == 'hamming'): | |||||
| self.dist_func = hamming_dist | |||||
| elif(dist_func == 'confidence'): | |||||
| self.dist_func = confidence_dist | |||||
| if pred_res_parse is None: | |||||
| if(dist_func == 'hamming'): | |||||
| pred_res_parse = lambda x : x["cls"] | |||||
| elif dist_func == 'confidence': | |||||
| pred_res_parse = lambda x : x["prob"] | |||||
| self.pred_res_parse = pred_res_parse | |||||
| assert(dist_func == 'hamming' or dist_func == 'confidence') | |||||
| self.dist_func = dist_func | |||||
| self.cache = cache | self.cache = cache | ||||
| self.cache_min_address_num = {} | self.cache_min_address_num = {} | ||||
| self.cache_candidates = {} | self.cache_candidates = {} | ||||
| def get_min_cost_candidate(self, pred_res, candidates): | |||||
| cost_list = self.dist_func(pred_res, candidates) | |||||
| def get_cost_list(self, pred_res, pred_res_prob, candidates): | |||||
| if(self.dist_func == 'hamming'): | |||||
| return hamming_dist(pred_res, candidates) | |||||
| elif(self.dist_func == 'confidence'): | |||||
| return confidence_dist(pred_res_prob, candidates) | |||||
| def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): | |||||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) | |||||
| min_address_num = np.min(cost_list) | min_address_num = np.min(cost_list) | ||||
| idxs = np.where(cost_list == min_address_num)[0] | idxs = np.where(cost_list == min_address_num)[0] | ||||
| return [candidates[idx] for idx in idxs][0] | return [candidates[idx] for idx in idxs][0] | ||||
| def abduce(self, data, max_address_num = -1, require_more_address = 0): | def abduce(self, data, max_address_num = -1, require_more_address = 0): | ||||
| pred_res, ans = data | |||||
| pred_res, pred_res_prob, ans = data | |||||
| if max_address_num == -1: | if max_address_num == -1: | ||||
| max_address_num = len(pred_res) | max_address_num = len(pred_res) | ||||
| @@ -74,12 +69,12 @@ class AbducerBase(abc.ABC): | |||||
| if((tuple(pred_res), ans, address_num) in self.cache_candidates): | if((tuple(pred_res), ans, address_num) in self.cache_candidates): | ||||
| # print('cached') | # print('cached') | ||||
| candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] | candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] | ||||
| candidates = self.get_min_cost_candidate(pred_res, candidates) | |||||
| candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||||
| return candidates | return candidates | ||||
| if(self.kb.base != {}): | if(self.kb.base != {}): | ||||
| all_candidates = self.kb.get_candidates(ans, len(pred_res)) | all_candidates = self.kb.get_candidates(ans, len(pred_res)) | ||||
| cost_list = self.dist_func(pred_res, all_candidates) | |||||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates) | |||||
| min_address_num = np.min(cost_list) | min_address_num = np.min(cost_list) | ||||
| address_num = min(max_address_num, min_address_num + require_more_address) | address_num = min(max_address_num, min_address_num + require_more_address) | ||||
| idxs = np.where(cost_list <= address_num)[0] | idxs = np.where(cost_list <= address_num)[0] | ||||
| @@ -92,7 +87,7 @@ class AbducerBase(abc.ABC): | |||||
| self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | ||||
| self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | ||||
| candidates = self.get_min_cost_candidate(pred_res, candidates) | |||||
| candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||||
| return candidates | return candidates | ||||
| def address(self, address_num, pred_res, key): | def address(self, address_num, pred_res, key): | ||||
| @@ -136,18 +131,18 @@ class AbducerBase(abc.ABC): | |||||
| return candidates, min_address_num, address_num | return candidates, min_address_num, address_num | ||||
| def batch_abduce(self, Y, C, max_address_num = -1, require_more_address = 0): | |||||
| def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): | |||||
| return [ | return [ | ||||
| self.abduce((y, c), max_address_num, require_more_address)\ | |||||
| for y, c in zip(self.pred_res_parse(Y), C) | |||||
| self.abduce((z, prob, y), max_address_num, require_more_address)\ | |||||
| for z, prob, y in zip(Z['cls'], Z['prob'], Y) | |||||
| ] | ] | ||||
| def __call__(self, Y, C, max_address_num = -1, require_more_address = 0): | |||||
| return self.batch_abduce(Y, C, max_address_num, require_more_address) | |||||
| def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0): | |||||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | |||||
| if __name__ == "__main__": | |||||
| if __name__ == '__main__': | |||||
| kb = add_KB() | kb = add_KB() | ||||
| abd = AbducerBase(kb) | abd = AbducerBase(kb) | ||||
| res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) | res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) | ||||
| @@ -166,7 +161,7 @@ if __name__ == "__main__": | |||||
| abd = AbducerBase(kb) | abd = AbducerBase(kb) | ||||
| res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0) | res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0) | ||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) | |||||
| res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 3, require_more_address = 0) | |||||
| print(res) | print(res) | ||||
| res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) | res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) | ||||
| print(res) | print(res) | ||||