From cf33651a7ade6cf1f68698a7ece95cc96ce72cc3 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Sat, 19 Nov 2022 10:25:51 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 51 +++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 23019b8..af9fdc8 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -38,34 +38,29 @@ def confidence_dist(A, B): class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func = 'hamming', pred_res_parse = None, cache = True): + def __init__(self, kb, dist_func = 'confidence', cache = True): self.kb = kb - - if(dist_func == 'hamming'): - self.dist_func = hamming_dist - elif(dist_func == 'confidence'): - self.dist_func = confidence_dist - - if pred_res_parse is None: - if(dist_func == 'hamming'): - pred_res_parse = lambda x : x["cls"] - elif dist_func == 'confidence': - pred_res_parse = lambda x : x["prob"] - self.pred_res_parse = pred_res_parse - + assert(dist_func == 'hamming' or dist_func == 'confidence') + self.dist_func = dist_func self.cache = cache self.cache_min_address_num = {} self.cache_candidates = {} - def get_min_cost_candidate(self, pred_res, candidates): - cost_list = self.dist_func(pred_res, candidates) + def get_cost_list(self, pred_res, pred_res_prob, candidates): + if(self.dist_func == 'hamming'): + return hamming_dist(pred_res, candidates) + elif(self.dist_func == 'confidence'): + return confidence_dist(pred_res_prob, candidates) + + def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): + cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) min_address_num = np.min(cost_list) idxs = np.where(cost_list == min_address_num)[0] return [candidates[idx] for idx in idxs][0] def abduce(self, data, max_address_num = -1, require_more_address = 0): - pred_res, ans = data - + pred_res, pred_res_prob, ans = data + if max_address_num == -1: max_address_num = len(pred_res) @@ -74,12 +69,12 @@ class AbducerBase(abc.ABC): if((tuple(pred_res), ans, address_num) in self.cache_candidates): # print('cached') candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] - candidates = self.get_min_cost_candidate(pred_res, candidates) + candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidates if(self.kb.base != {}): all_candidates = self.kb.get_candidates(ans, len(pred_res)) - cost_list = self.dist_func(pred_res, all_candidates) + cost_list = self.get_cost_list(pred_res, pred_res_prob, all_candidates) min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) idxs = np.where(cost_list <= address_num)[0] @@ -92,7 +87,7 @@ class AbducerBase(abc.ABC): self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates - candidates = self.get_min_cost_candidate(pred_res, candidates) + candidates = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidates def address(self, address_num, pred_res, key): @@ -136,18 +131,18 @@ class AbducerBase(abc.ABC): return candidates, min_address_num, address_num - def batch_abduce(self, Y, C, max_address_num = -1, require_more_address = 0): + def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): return [ - self.abduce((y, c), max_address_num, require_more_address)\ - for y, c in zip(self.pred_res_parse(Y), C) + self.abduce((z, prob, y), max_address_num, require_more_address)\ + for z, prob, y in zip(Z['cls'], Z['prob'], Y) ] - def __call__(self, Y, C, max_address_num = -1, require_more_address = 0): - return self.batch_abduce(Y, C, max_address_num, require_more_address) + def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0): + return self.batch_abduce(Z, Y, max_address_num, require_more_address) -if __name__ == "__main__": +if __name__ == '__main__': kb = add_KB() abd = AbducerBase(kb) res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) @@ -166,7 +161,7 @@ if __name__ == "__main__": abd = AbducerBase(kb) res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 2, require_more_address = 0) + res = abd.abduce((['5', '+', '2'], 1.67), max_address_num = 3, require_more_address = 0) print(res) res = abd.abduce((['5', '+', '3'], 0.33), max_address_num = 3, require_more_address = 3) print(res)