diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 576cce4..1168ea6 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -11,7 +11,8 @@ #================================================================# import abc -from kb import add_KB, hwf_KB +# from kb import add_KB, hwf_KB +from abducer.kb import add_KB, hwf_KB import numpy as np from itertools import product, combinations @@ -37,18 +38,19 @@ def confidence_dist(A, B): class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True): + def __init__(self, kb, dist_func = 'hamming', pred_res_parse = None, cache = True): self.kb = kb - if dist_func == "hamming": - dist_func = hamming_dist - elif dist_func == "confidence": - dist_func = confidence_dist - self.dist_func = dist_func + + if(dist_func == 'hamming'): + self.dist_func = hamming_dist + elif(dist_func == 'confidence'): + self.dist_func = confidence_dist + if pred_res_parse is None: - if(dist_func == "hamming"): + if(dist_func == 'hamming'): pred_res_parse = lambda x : x["cls"] - elif dist_func == "confidence": - pred_res_parse = lambda x : x[" "] + elif dist_func == 'confidence': + pred_res_parse = lambda x : x["prob"] self.pred_res_parse = pred_res_parse self.cache = cache @@ -61,20 +63,24 @@ class AbducerBase(abc.ABC): idxs = np.where(cost_list == min_address_num)[0] return [candidates[idx] for idx in idxs][0] - def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): + def abduce(self, data, max_address_num = -1, require_more_address = 0): pred_res, ans = data + pred_res = [self.kb.pseudo_label_list[sym] for sym in pred_res] + - if length == -1: - length = len(pred_res) + if max_address_num == -1: + max_address_num = len(pred_res) if(self.cache and (tuple(pred_res), ans) in self.cache_min_address_num): address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) if((tuple(pred_res), ans, address_num) in self.cache_candidates): - print('cached') - return self.cache_candidates[(tuple(pred_res), ans, address_num)] + # print('cached') + candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] + candidates = self.get_min_cost_candidate(pred_res, candidates) + return candidates if(self.kb.base != {}): - all_candidates = self.kb.get_candidates(ans, length) + all_candidates = self.kb.get_candidates(ans, len(pred_res)) cost_list = self.dist_func(pred_res, all_candidates) min_address_num = np.min(cost_list) address_num = min(max_address_num, min_address_num + require_more_address) @@ -107,7 +113,6 @@ class AbducerBase(abc.ABC): def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): candidates = [] - for address_num in range(len(pred_res) + 1): if(address_num > max_address_num): print('No candidates found') @@ -132,22 +137,20 @@ class AbducerBase(abc.ABC): return candidates, min_address_num, address_num - - - def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0): + + def batch_abduce(self, Y, C, max_address_num = -1, require_more_address = 0): return [ self.abduce((y, c), max_address_num, require_more_address)\ for y, c in zip(self.pred_res_parse(Y), C) ] - def __call__(self, Y, C, max_address_num = 3, require_more_address = 0): + def __call__(self, Y, C, max_address_num = -1, require_more_address = 0): return self.batch_abduce(Y, C, max_address_num, require_more_address) if __name__ == "__main__": - pseudo_label_list = list(range(10)) - kb = add_KB(pseudo_label_list) + kb = add_KB() abd = AbducerBase(kb) res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) print(res) @@ -161,8 +164,7 @@ if __name__ == "__main__": print(res) print() - pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '+', '-', '*', '/'] - kb = hwf_KB(pseudo_label_list) + kb = hwf_KB() abd = AbducerBase(kb) res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0) print(res)