| @@ -17,23 +17,7 @@ import numpy as np | |||
| from itertools import product, combinations | |||
def hamming_dist(A, B):
    """Count, for every row of ``B``, the positions where it differs from ``A``.

    Args:
        A: The reference sequence. Presumably 1-D with one entry per
            position (inferred from the axis-1 reduction; confirm against
            the caller).
        B: A collection of sequences; each one is compared element-wise
            against ``A``.

    Returns:
        A NumPy array with one mismatch count per row of ``B``.
    """
    candidates = np.array(B)
    reference = np.asarray(A)
    # Present the reference once per candidate row so the comparison below
    # is strictly element-wise against every candidate.
    tiled = np.broadcast_to(reference, (len(candidates),) + reference.shape)
    mismatches = tiled != candidates
    return np.sum(mismatches, axis=1)
def confidence_dist(A, B):
    """Score each row of ``B`` by how unlikely it is under the values in ``A``.

    For every row ``b`` of ``B`` the result is
    ``1 - prod_j clip(A)[j, b[j]]``, so a smaller value means the row is
    better supported by ``A``.

    Args:
        A: Values indexed as ``A[position, label]``. Assumed to be 2-D
            probabilities (inferred from the three-index lookup after one
            extra leading axis is added; confirm against the caller).
        B: Rows of integer label indices, one index per position.

    Returns:
        A NumPy array with one score per row of ``B``.
    """
    labels = np.array(B)
    # Clip so that a value of exactly zero cannot wipe out the whole product.
    probs = np.clip(A, 1e-9, 1)
    n_candidates = len(labels)
    n_positions = len(labels[0])
    # One view of the table per candidate row.
    tiled = np.broadcast_to(probs, (n_candidates,) + probs.shape)
    # row_idx[i, j] == i and col_idx[i, j] == j: grids used to pick, for
    # candidate i and position j, the entry of the label chosen there.
    row_idx, col_idx = np.indices((n_candidates, n_positions))
    chosen = tiled[row_idx, col_idx, labels]
    return 1 - np.prod(chosen, axis=1)
| @@ -46,11 +30,31 @@ class AbducerBase(abc.ABC): | |||
| self.cache_min_address_num = {} | |||
| self.cache_candidates = {} | |||
def hamming_dist(self, A, B):
    """Count, for every row of ``B``, the positions where it differs from ``A``.

    Args:
        A: The reference sequence. Presumably 1-D with one entry per
            position (inferred from the axis-1 reduction; confirm against
            the caller).
        B: A collection of sequences; each one is compared element-wise
            against ``A``.

    Returns:
        A NumPy array with one mismatch count per row of ``B``.
    """
    candidates = np.array(B)
    reference = np.asarray(A)
    # Present the reference once per candidate row so the comparison below
    # is strictly element-wise against every candidate.
    tiled = np.broadcast_to(reference, (len(candidates),) + reference.shape)
    mismatches = tiled != candidates
    return np.sum(mismatches, axis=1)
def confidence_dist(self, A, B):
    """Score each row of ``B`` by how unlikely it is under the values in ``A``.

    Labels in ``B`` are first translated to their index in
    ``self.kb.pseudo_label_list``. For every translated row ``b`` the result
    is ``1 - prod_j clip(A)[j, b[j]]``, so a smaller value means the row is
    better supported by ``A``.

    Args:
        A: Values indexed as ``A[position, label_index]``. Assumed to be
            2-D probabilities whose second axis follows the order of
            ``self.kb.pseudo_label_list`` (inferred from the lookup;
            confirm against the caller).
        B: Rows of labels, each label being an element of
            ``self.kb.pseudo_label_list``.

    Returns:
        A NumPy array with one score per row of ``B``.

    Raises:
        KeyError: If a label in ``B`` is not in ``self.kb.pseudo_label_list``.
    """
    label_to_index = {
        label: index for index, label in enumerate(self.kb.pseudo_label_list)
    }
    encoded = np.array(
        [[label_to_index[label] for label in candidate] for candidate in B]
    )
    # Clip so that a value of exactly zero cannot wipe out the whole product.
    probs = np.clip(A, 1e-9, 1)
    n_candidates = len(encoded)
    n_positions = len(encoded[0])
    # One view of the table per candidate row.
    tiled = np.broadcast_to(probs, (n_candidates,) + probs.shape)
    # row_idx[i, j] == i and col_idx[i, j] == j: grids used to pick, for
    # candidate i and position j, the entry of the label chosen there.
    row_idx, col_idx = np.indices((n_candidates, n_positions))
    chosen = tiled[row_idx, col_idx, encoded]
    return 1 - np.prod(chosen, axis=1)
def get_cost_list(self, pred_res, pred_res_prob, candidates):
    """Return one cost per candidate using the metric named by ``self.dist_func``.

    Each branch previously held two consecutive ``return`` statements, which
    left the second one unreachable. Only the calls to this object's own
    ``hamming_dist`` / ``confidence_dist`` methods are kept.

    Args:
        pred_res: Predicted labels; passed to ``self.hamming_dist`` when
            ``self.dist_func == 'hamming'``.
        pred_res_prob: Values passed to ``self.confidence_dist`` when
            ``self.dist_func == 'confidence'``.
        candidates: Candidate label sequences to score.

    Returns:
        Whatever the selected distance method returns for ``candidates``.

    Raises:
        ValueError: If ``self.dist_func`` is neither ``'hamming'`` nor
            ``'confidence'``. Previously this fell through and returned
            ``None`` implicitly.
    """
    if self.dist_func == 'hamming':
        return self.hamming_dist(pred_res, candidates)
    if self.dist_func == 'confidence':
        return self.confidence_dist(pred_res_prob, candidates)
    raise ValueError(
        "unsupported dist_func %r; expected 'hamming' or 'confidence'"
        % (self.dist_func,)
    )
| def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): | |||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) | |||
| @@ -60,7 +64,6 @@ class AbducerBase(abc.ABC): | |||
| def abduce(self, data, max_address_num = -1, require_more_address = 0): | |||
| pred_res, pred_res_prob, ans = data | |||
| if max_address_num == -1: | |||
| max_address_num = len(pred_res) | |||
| @@ -69,12 +72,12 @@ class AbducerBase(abc.ABC): | |||
| if((tuple(pred_res), ans, address_num) in self.cache_candidates): | |||
| # print('cached') | |||
| candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] | |||
| candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidates | |||
| candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidate | |||
| if(self.kb.base != {}): | |||
| all_candidates = self.kb.get_candidates(ans, len(pred_res)) | |||
| cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates) | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| @@ -87,8 +90,8 @@ class AbducerBase(abc.ABC): | |||
| self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num | |||
| self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates | |||
| candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidates | |||
| candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) | |||
| return candidate | |||
| def address(self, address_num, pred_res, key): | |||
| new_candidates = [] | |||