Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
cf33651a7a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 28 deletions
  1. +23
    -28
      abducer/abducer_base.py

+ 23
- 28
abducer/abducer_base.py View File

@@ -38,34 +38,29 @@ def confidence_dist(A, B):


class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = 'hamming', pred_res_parse = None, cache = True):
def __init__(self, kb, dist_func = 'confidence', cache = True):
self.kb = kb
if(dist_func == 'hamming'):
self.dist_func = hamming_dist
elif(dist_func == 'confidence'):
self.dist_func = confidence_dist
if pred_res_parse is None:
if(dist_func == 'hamming'):
pred_res_parse = lambda x : x["cls"]
elif dist_func == 'confidence':
pred_res_parse = lambda x : x["prob"]
self.pred_res_parse = pred_res_parse
assert(dist_func == 'hamming' or dist_func == 'confidence')
self.dist_func = dist_func
self.cache = cache
self.cache_min_address_num = {}
self.cache_candidates = {}

def get_min_cost_candidate(self, pred_res, candidates):
cost_list = self.dist_func(pred_res, candidates)
def get_cost_list(self, pred_res, pred_res_prob, candidates):
if(self.dist_func == 'hamming'):
return hamming_dist(pred_res, candidates)
elif(self.dist_func == 'confidence'):
return confidence_dist(pred_res_prob, candidates)

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]

def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, ans = data
pred_res, pred_res_prob, ans = data
if max_address_num == -1:
max_address_num = len(pred_res)

@@ -74,12 +69,12 @@ class AbducerBase(abc.ABC):
if((tuple(pred_res), ans, address_num) in self.cache_candidates):
# print('cached')
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidates = self.get_min_cost_candidate(pred_res, candidates)
candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidates
if(self.kb.base != {}):
all_candidates = self.kb.get_candidates(ans, len(pred_res))
cost_list = self.dist_func(pred_res, all_candidates)
cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
@@ -92,7 +87,7 @@ class AbducerBase(abc.ABC):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates

candidates = self.get_min_cost_candidate(pred_res, candidates)
candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidates
def address(self, address_num, pred_res, key):
@@ -136,18 +131,18 @@ class AbducerBase(abc.ABC):
return candidates, min_address_num, address_num
def batch_abduce(self, Y, C, max_address_num = -1, require_more_address = 0):
def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0):
return [
self.abduce((y, c), max_address_num, require_more_address)\
for y, c in zip(self.pred_res_parse(Y), C)
self.abduce((z, prob, y), max_address_num, require_more_address)\
for z, prob, y in zip(Z['cls'], Z['prob'], Y)
]

def __call__(self, Y, C, max_address_num = -1, require_more_address = 0):
return self.batch_abduce(Y, C, max_address_num, require_more_address)
def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address)



if __name__ == "__main__":
if __name__ == '__main__':
kb = add_KB()
abd = AbducerBase(kb)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
@@ -166,7 +161,7 @@ if __name__ == "__main__":
abd = AbducerBase(kb)
res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0)
res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 3, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3)
print(res)


Loading…
Cancel
Save