From 7edd652eea43fb43824db61b7bf1929fc0b12779 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 15 Nov 2022 21:23:15 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 77 +++++++++++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 15 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index f990013..b1c661b 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -14,15 +14,14 @@ import abc from kb import add_KB import numpy as np -def hamming_dist(A, B): - return np.sum(np.array(A) != np.array(B)) +from itertools import product, combinations -def hamming_dist_kb(A, B): +def hamming_dist(A, B): B = np.array(B) A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) return np.sum(A != B, axis = 1) -def confidence_dist_kb(A, B): +def confidence_dist(A, B): B = np.array(B) #print(A) @@ -41,7 +40,10 @@ class AbducerBase(abc.ABC): def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True): self.kb = kb if dist_func == "hamming": - self.dist_func = hamming_dist + dist_func = hamming_dist + elif dist_func == "confidence": + dist_func = confidence_dist + self.dist_func = dist_func if pred_res_parse is None: pred_res_parse = lambda x : x["cls"] self.pred_res_parse = pred_res_parse @@ -50,6 +52,7 @@ class AbducerBase(abc.ABC): self.cache_min_address_num = {} self.cache_candidates = {} + def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1): pred_res, ans = data @@ -62,8 +65,16 @@ class AbducerBase(abc.ABC): print('cached') return self.cache_candidates[(tuple(pred_res), ans, address_num)] - - candidates, min_address_num, address_num = self.kb.get_abduce_candidates(pred_res, ans, length, self.dist_func, max_address_num, require_more_address) + if(self.kb.base != {}): + all_candidates = self.kb.get_candidates(ans, length) + cost_list = self.dist_func(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + + else: + candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) if(self.cache): self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num @@ -71,21 +82,57 @@ class AbducerBase(abc.ABC): return candidates - # candidates = self.kb.get_candidates(ans, length) - # cost_list = self.dist_func(pred_res, candidates) - # address_num = np.min(cost_list) - # # threshold = min(address_num + require_more_address, max_address_num) - # idxs = np.where(cost_list <= address_num + require_more_address)[0] - - # return [candidates[idx] for idx in idxs], address_num # if len(idxs) > 1: # return None # return [candidates[idx] for idx in idxs] + + def address(self, address_num, pred_res, key): + new_candidates = [] + all_address_candidate = list(product(self.kb.pseudo_label_list, repeat = address_num)) + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + for address_idx in address_idx_list: + for c in all_address_candidate: + pred_res_array = np.array(pred_res) + if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): + pred_res_array[np.array(address_idx)] = c + if(self.kb.logic_forward(pred_res_array) == key): + new_candidates.append(pred_res_array) + return new_candidates, address_num + + def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): + + candidates = [] + for address_num in range(len(pred_res) + 1): + if(address_num > max_address_num): + print('No candidates found') + return None, None, None + + if(address_num == 0): + if(self.kb.logic_forward(pred_res) == key): + candidates.append(pred_res) + else: + new_candidates, address_num = self.address(address_num, pred_res, key) + candidates += new_candidates + + if(len(candidates) > 0): + min_address_num = address_num + break + + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if(address_num > max_address_num): + return candidates, min_address_num, address_num - 1 + new_candidates, address_num = self.address(address_num, pred_res, key) + candidates += new_candidates + + return candidates, min_address_num, address_num + + + def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0): return [ self.abduce((y, c), max_address_num, require_more_address)\ @@ -101,7 +148,7 @@ if __name__ == "__main__": pseudo_label_list = list(range(10)) kb = add_KB(pseudo_label_list) abd = AbducerBase(kb) - res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1) print(res) print() res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)