| @@ -14,15 +14,14 @@ import abc | |||
| from kb import add_KB | |||
| import numpy as np | |||
def hamming_dist(A, B):
    """Count the element-wise positions at which A and B hold different values.

    Both arguments are converted with ``np.array`` before being compared, so
    plain Python sequences are accepted. The result is the sum over the whole
    comparison array.
    """
    lhs = np.array(A)
    rhs = np.array(B)
    return np.sum(lhs != rhs)
| from itertools import product, combinations | |||
| def hamming_dist_kb(A, B): | |||
def hamming_dist(A, B):
    """Return, for every entry of B, how many positions differ from A.

    A is stacked ``len(B)`` times along a new leading axis and compared
    element-wise with ``np.array(B)``; mismatches are summed along axis 1.

    NOTE(review): presumably A is one flat prediction and B a list of
    equally long candidates -- confirm against the callers.
    """
    candidates = np.array(B)
    stacked = np.repeat(np.expand_dims(A, axis=0), len(candidates), axis=0)
    return np.sum(stacked != candidates, axis=1)
| def confidence_dist_kb(A, B): | |||
| def confidence_dist(A, B): | |||
| B = np.array(B) | |||
| #print(A) | |||
| @@ -41,7 +40,10 @@ class AbducerBase(abc.ABC): | |||
| def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True): | |||
| self.kb = kb | |||
| if dist_func == "hamming": | |||
| self.dist_func = hamming_dist | |||
| dist_func = hamming_dist | |||
| elif dist_func == "confidence": | |||
| dist_func = confidence_dist | |||
| self.dist_func = dist_func | |||
| if pred_res_parse is None: | |||
| pred_res_parse = lambda x : x["cls"] | |||
| self.pred_res_parse = pred_res_parse | |||
| @@ -50,6 +52,7 @@ class AbducerBase(abc.ABC): | |||
| self.cache_min_address_num = {} | |||
| self.cache_candidates = {} | |||
| def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): | |||
| pred_res, ans = data | |||
| @@ -62,8 +65,16 @@ class AbducerBase(abc.ABC): | |||
| print('cached') | |||
| return self.cache_candidates[(tuple(pred_res), ans, address_num)] | |||
| candidates, min_address_num, address_num = self.kb.get_abduce_candidates(pred_res, ans, length, self.dist_func, max_address_num, require_more_address) | |||
| if(self.kb.base != {}): | |||
| all_candidates = self.kb.get_candidates(ans, length) | |||
| cost_list = self.dist_func(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| else: | |||
| candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) | |||
| if(self.cache): | |||
| self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | |||
| @@ -71,21 +82,57 @@ class AbducerBase(abc.ABC): | |||
| return candidates | |||
| # candidates = self.kb.get_candidates(ans, length) | |||
| # cost_list = self.dist_func(pred_res, candidates) | |||
| # address_num = np.min(cost_list) | |||
| # # threshold = min(address_num + require_more_address, max_address_num) | |||
| # idxs = np.where(cost_list <= address_num + require_more_address)[0] | |||
| # return [candidates[idx] for idx in idxs], address_num | |||
| # if len(idxs) > 1: | |||
| # return None | |||
| # return [candidates[idx] for idx in idxs] | |||
def address(self, address_num, pred_res, key):
    """Collect revisions of pred_res that change exactly address_num entries.

    Every combination of ``address_num`` positions in ``pred_res`` is paired
    with every tuple of ``address_num`` labels drawn from
    ``self.kb.pseudo_label_list``. A pairing is kept only when each chosen
    position receives a label different from its current one and
    ``self.kb.logic_forward`` of the revised sequence equals ``key``.

    Returns:
        A 2-tuple ``(candidates, address_num)``: the accepted revisions as
        numpy arrays, and ``address_num`` passed back unchanged.
    """
    matches = []
    label_tuples = list(product(self.kb.pseudo_label_list, repeat=address_num))
    position_tuples = combinations(range(len(pred_res)), address_num)

    for positions in position_tuples:
        where = np.array(positions)
        for labels in label_tuples:
            # Fresh copy per attempt: the array is modified in place below.
            revised = np.array(pred_res)
            changed = np.count_nonzero(np.array(labels) != revised[where])
            if changed != address_num:
                # At least one chosen position would keep its old label.
                continue
            revised[where] = labels
            if self.kb.logic_forward(revised) == key:
                matches.append(revised)

    return matches, address_num
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
    """Search for revisions of pred_res whose logic_forward result equals key.

    The search tries 0, 1, 2, ... revised positions (0 meaning ``pred_res``
    itself) and stops at the first count that yields a candidate; that
    count is ``min_address_num``. It then adds the candidates for up to
    ``require_more_address`` further counts, never going above
    ``max_address_num``.

    Args:
        pred_res: the predicted label sequence to revise.
        key: the value ``self.kb.logic_forward`` must produce.
        max_address_num: largest number of revised positions allowed.
        require_more_address: how many counts beyond the minimum to include.

    Returns:
        ``(candidates, min_address_num, address_num)`` where ``address_num``
        is the last revision count that was searched, or
        ``(None, None, None)`` when no candidate exists within the limits.
    """
    candidates = []
    min_address_num = None

    for address_num in range(len(pred_res) + 1):
        if address_num > max_address_num:
            break
        if address_num == 0:
            if self.kb.logic_forward(pred_res) == key:
                candidates.append(pred_res)
        else:
            new_candidates, address_num = self.address(address_num, pred_res, key)
            candidates += new_candidates
        if candidates:
            min_address_num = address_num
            break

    # Covers both "exceeded max_address_num" and "tried every possible
    # revision count without success". The second case previously fell
    # through with min_address_num unbound and raised UnboundLocalError.
    if min_address_num is None:
        print('No candidates found')
        return None, None, None

    for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
        if address_num > max_address_num:
            return candidates, min_address_num, address_num - 1
        new_candidates, address_num = self.address(address_num, pred_res, key)
        candidates += new_candidates

    return candidates, min_address_num, address_num
| def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0): | |||
| return [ | |||
| self.abduce((y, c), max_address_num, require_more_address)\ | |||
| @@ -101,7 +148,7 @@ if __name__ == "__main__": | |||
| pseudo_label_list = list(range(10)) | |||
| kb = add_KB(pseudo_label_list) | |||
| abd = AbducerBase(kb) | |||
| res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) | |||
| res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) | |||
| print(res) | |||
| print() | |||
| res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) | |||