From 40a3889c68dd4472007ea96d193eec52711098d6 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Sun, 20 Nov 2022 22:01:50 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 51 ++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index c94bca4..bbf00fa 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -17,23 +17,7 @@ import numpy as np from itertools import product, combinations -def hamming_dist(A, B): - B = np.array(B) - A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) - return np.sum(A != B, axis = 1) -def confidence_dist(A, B): - B = np.array(B) - - #print(A) - A = np.clip(A, 1e-9, 1) - A = np.expand_dims(A, axis=0) - A = A.repeat(axis=0, repeats=(len(B))) - rows = np.array(range(len(B))) - rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) - cols = np.array(range(len(B[0]))) - cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) - return 1 - np.prod(A[rows, cols, B], axis = 1) @@ -46,11 +30,31 @@ class AbducerBase(abc.ABC): self.cache_min_address_num = {} self.cache_candidates = {} + def hamming_dist(self, A, B): + B = np.array(B) + A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) + return np.sum(A != B, axis = 1) + + def confidence_dist(self, A, B): + mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) + B = [list(map(lambda x : mapping[x], b)) for b in B] + + B = np.array(B) + A = np.clip(A, 1e-9, 1) + A = np.expand_dims(A, axis=0) + A = A.repeat(axis=0, repeats=(len(B))) + rows = np.array(range(len(B))) + rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) + cols = np.array(range(len(B[0]))) + cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) + return 1 - np.prod(A[rows, cols, B], axis = 1) + + def get_cost_list(self, pred_res, pred_res_prob, candidates): if(self.dist_func == 'hamming'): - return hamming_dist(pred_res, candidates) + return self.hamming_dist(pred_res, candidates) elif(self.dist_func == 'confidence'): - return confidence_dist(pred_res_prob, candidates) + return self.confidence_dist(pred_res_prob, candidates) def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) @@ -60,7 +64,6 @@ class AbducerBase(abc.ABC): def abduce(self, data, max_address_num = -1, require_more_address = 0): pred_res, pred_res_prob, ans = data - if max_address_num == -1: max_address_num = len(pred_res) @@ -69,12 +72,12 @@ class AbducerBase(abc.ABC): if((tuple(pred_res), ans, address_num) in self.cache_candidates): # print('cached') candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] - candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) - return candidates + candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) + return candidate if(self.kb.base != {}): all_candidates = self.kb.get_candidates(ans, len(pred_res)) - cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates) + cost_list = self.hamming_dist(pred_res, all_candidates) min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) idxs = np.where(cost_list <= address_num)[0] @@ -87,8 +90,8 @@ class AbducerBase(abc.ABC): self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates - candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) - return candidates + candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) + return candidate def address(self, address_num, pred_res, key): new_candidates = []