Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
40a3889c68
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 27 additions and 24 deletions
  1. +27
    -24
      abducer/abducer_base.py

+ 27
- 24
abducer/abducer_base.py View File

@@ -17,23 +17,7 @@ import numpy as np

from itertools import product, combinations

def hamming_dist(A, B):
    """Return the Hamming distance between A and every candidate in B.

    Args:
        A: one sequence of labels.
        B: a sequence of candidate label sequences, each the same length
            as ``A``.

    Returns:
        A 1-D integer array with one entry per candidate: the number of
        positions at which that candidate differs from ``A``. An empty
        candidate list yields an empty array.
    """
    B = np.asarray(B)
    if len(B) == 0:
        # Nothing to compare against; comparing a non-empty A with an empty
        # array cannot broadcast.
        return np.zeros(0, dtype=int)
    # Broadcasting compares A with every row of B without materialising
    # len(B) copies of A.
    return np.sum(np.asarray(A) != B, axis=1)

def confidence_dist(A, B):
    """Return a confidence-based cost for every candidate in B.

    For each candidate the cost is ``1 - prod_j A[j, candidate[j]]``, i.e.
    one minus the product of the values ``A`` assigns to the candidate's
    label at each position.

    Args:
        A: 2-D array-like indexed as ``A[position, label_index]``. Values
            are clipped to ``[1e-9, 1]`` so a single zero entry does not
            collapse the whole product to exactly zero.
        B: a sequence of candidates, each a sequence of integer label
            indices with one entry per position.

    Returns:
        A 1-D float array with one cost per candidate. An empty candidate
        list yields an empty array.
    """
    B = np.asarray(B)
    if len(B) == 0:
        return np.zeros(0)
    A = np.clip(A, 1e-9, 1)
    # Pair every position j with the label each candidate puts there; the
    # (n,) position index broadcasts against the (m, n) candidate matrix,
    # so no repeated copy of A or explicit index grids are needed.
    positions = np.arange(B.shape[1])
    return 1 - np.prod(A[positions, B], axis=1)



@@ -46,11 +30,31 @@ class AbducerBase(abc.ABC):
self.cache_min_address_num = {}
self.cache_candidates = {}

def hamming_dist(self, A, B):
    """Return the Hamming distance between A and every candidate in B.

    Args:
        A: one sequence of labels.
        B: a sequence of candidate label sequences, each the same length
            as ``A``.

    Returns:
        A 1-D integer array with one entry per candidate: the number of
        positions at which that candidate differs from ``A``. An empty
        candidate list yields an empty array.
    """
    B = np.asarray(B)
    if len(B) == 0:
        # Nothing to compare against; comparing a non-empty A with an empty
        # array cannot broadcast.
        return np.zeros(0, dtype=int)
    # Broadcasting compares A with every row of B without materialising
    # len(B) copies of A.
    return np.sum(np.asarray(A) != B, axis=1)

def confidence_dist(self, A, B):
    """Return a confidence-based cost for every candidate in B.

    Candidate labels are first translated to their index in
    ``self.kb.pseudo_label_list``. For each candidate the cost is then
    ``1 - prod_j A[j, index_of(candidate[j])]``.

    Args:
        A: 2-D array-like indexed as ``A[position, label_index]``. Values
            are clipped to ``[1e-9, 1]`` so a single zero entry does not
            collapse the whole product to exactly zero.
        B: a sequence of candidates, each a sequence of labels drawn from
            ``self.kb.pseudo_label_list`` with one entry per position.

    Returns:
        A 1-D float array with one cost per candidate. An empty candidate
        list yields an empty array.

    Raises:
        KeyError: if a candidate contains a label that is not in
            ``self.kb.pseudo_label_list``.
    """
    label_to_index = {
        label: index for index, label in enumerate(self.kb.pseudo_label_list)
    }
    B = np.asarray([[label_to_index[label] for label in cand] for cand in B])
    if len(B) == 0:
        return np.zeros(0)
    A = np.clip(A, 1e-9, 1)
    # Pair every position j with the label index each candidate puts there;
    # the (n,) position index broadcasts against the (m, n) candidate
    # matrix, so no repeated copy of A or explicit index grids are needed.
    positions = np.arange(B.shape[1])
    return 1 - np.prod(A[positions, B], axis=1)
def get_cost_list(self, pred_res, pred_res_prob, candidates):
    """Return the cost of every candidate under the configured distance.

    Args:
        pred_res: predicted labels; used when ``self.dist_func`` is
            ``'hamming'``.
        pred_res_prob: per-position prediction scores; used when
            ``self.dist_func`` is ``'confidence'``.
        candidates: the candidate label sequences to score.

    Returns:
        Whatever the selected distance method returns for ``candidates``.

    Raises:
        ValueError: if ``self.dist_func`` is neither ``'hamming'`` nor
            ``'confidence'``.
    """
    # Dispatch to the instance methods; the stale calls to the former
    # module-level functions made these branches unreachable.
    if self.dist_func == 'hamming':
        return self.hamming_dist(pred_res, candidates)
    if self.dist_func == 'confidence':
        return self.confidence_dist(pred_res_prob, candidates)
    # Fail loudly instead of implicitly returning None to the caller.
    raise ValueError(f"Unknown dist_func: {self.dist_func!r}")

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
@@ -60,7 +64,6 @@ class AbducerBase(abc.ABC):

def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, pred_res_prob, ans = data
if max_address_num == -1:
max_address_num = len(pred_res)

@@ -69,12 +72,12 @@ class AbducerBase(abc.ABC):
if((tuple(pred_res), ans, address_num) in self.cache_candidates):
# print('cached')
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidates
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
if(self.kb.base != {}):
all_candidates = self.kb.get_candidates(ans, len(pred_res))
cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates)
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
@@ -87,8 +90,8 @@ class AbducerBase(abc.ABC):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates

candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidates
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
def address(self, address_num, pred_res, key):
new_candidates = []


Loading…
Cancel
Save