From 5beb068ee1c3d9f902fcefc27a1dcf93402b6b28 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 15 Nov 2022 21:23:00 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 156 +++++++++++++++----------------------------------- 1 file changed, 46 insertions(+), 110 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index b60f50e..95e9692 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -16,8 +16,7 @@ import copy import numpy as np from collections import defaultdict - -from itertools import product, combinations +from itertools import product class KBBase(ABC): def __init__(self): @@ -46,16 +45,17 @@ class KBBase(ABC): pass class add_KB(KBBase): - def __init__(self, pseudo_label_list, max_len = 5): + def __init__(self, pseudo_label_list, kb_max_len = -1): super().__init__() self.pseudo_label_list = pseudo_label_list self.base = {} - - X = self.get_X(self.pseudo_label_list, max_len) - Y = self.get_Y(X, self.logic_forward) + self.kb_max_len = kb_max_len + if(self.kb_max_len > 0): + X = self.get_X(self.pseudo_label_list, self.kb_max_len) + Y = self.get_Y(X, self.logic_forward) - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) def logic_forward(self, nums): return sum(nums) @@ -71,60 +71,20 @@ class add_KB(KBBase): return [logic_forward(nums) for nums in X] def get_candidates(self, key, length = None): + if(self.base == {}): + return [] + if key is None: return self.get_all_candidates() length = self._length(length) + if(self.kb_max_len < min(length)): + return [] return sum([self.base[l][key] for l in length], []) def get_all_candidates(self): return sum([sum(v.values(), []) for v in self.base.values()], []) - def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): - if key is None: - return self.get_all_candidates() - - candidates = [] - - for address_num in range(len(pred_res) + 1): - if(address_num > max_address_num): - print('No candidates found') - return None, None, None - - if(address_num == 0): - if(self.logic_forward(pred_res) == key): - candidates.append(pred_res) - else: - all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) - for address_idx in address_idx_list: - for c in all_address_candidate: - pred_res_array = np.array(pred_res) - if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): - pred_res_array[np.array(address_idx)] = c - if(self.logic_forward(pred_res_array) == key): - candidates.append(pred_res_array) - - if(len(candidates) > 0): - min_address_num = address_num - break - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if(address_num > max_address_num): - return candidates, min_address_num, address_num - 1 - all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) - for address_idx in address_idx_list: - for c in all_candidate: - pred_res_array = np.array(pred_res) - if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): - pred_res_array[np.array(address_idx)] = c - if(self.logic_forward(pred_res_array) == key): - candidates.append(pred_res_array) - - return candidates, min_address_num, address_num - - def _dict_len(self, dic): return sum(len(c) for c in dic.values()) @@ -132,16 +92,17 @@ class add_KB(KBBase): return sum(self._dict_len(v) for v in self.base.values()) # class hwf_KB(KBBase): -# def __init__(self, pseudo_label_list, max_len = 5): +# def __init__(self, pseudo_label_list, kb_max_len = -1): # super().__init__() # self.pseudo_label_list = pseudo_label_list # self.base = {} - -# X = self.get_X(self.pseudo_label_list, max_len) -# Y = self.get_Y(X, self.logic_forward) +# self.kb_max_len = kb_max_len +# if(self.kb_max_len > 0): +# X = self.get_X(self.pseudo_label_list, self.kb_max_len) +# Y = self.get_Y(X, self.logic_forward) -# for x, y in zip(X, Y): -# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) +# for x, y in zip(X, Y): +# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) # def logic_forward(self, nums): # return sum(nums) @@ -157,65 +118,26 @@ class add_KB(KBBase): # return [logic_forward(nums) for nums in X] # def get_candidates(self, key, length = None): +# if(self.base == {}): +# return [] + # if key is None: # return self.get_all_candidates() # length = self._length(length) +# if(self.kb_max_len < min(length)): +# return [] # return sum([self.base[l][key] for l in length], []) # def get_all_candidates(self): # return sum([sum(v.values(), []) for v in self.base.values()], []) - -# def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): -# if key is None: -# return self.get_all_candidates() - -# candidates = [] -# # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) - -# for address_num in range(length + 1): -# if(address_num > max_address_num): -# print('No candidates found') -# return None, None, None - -# if(address_num == 0): -# if(self.logic_forward(pred_res) == key): -# candidates.append(pred_res) -# else: -# all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) -# address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) -# for address_idx in address_idx_list: -# for c in all_address_candidate: -# pred_res_array = np.array(pred_res) -# pred_res_array[np.array(address_idx)] = c -# if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): -# candidates.append(pred_res_array) - -# if(len(candidates) > 0): -# min_address_num = address_num -# break - -# for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): -# if(address_num > max_address_num): -# return candidates, min_address_num, address_num - 1 -# all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) -# address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) -# for address_idx in address_idx_list: -# for c in all_candidate: -# pred_res_array = np.array(pred_res) -# pred_res_array[np.array(address_idx)] = c -# if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): -# candidates.append(pred_res_array) - -# return candidates, min_address_num, address_num - # def _dict_len(self, dic): # return sum(len(c) for c in dic.values()) # def __len__(self): # return sum(self._dict_len(v) for v in self.base.values()) - + class cls_KB(KBBase): def __init__(self, X, Y = None): super().__init__() @@ -298,25 +220,39 @@ class reg_KB(KBBase): return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) if __name__ == "__main__": + # With ground KB pseudo_label_list = list(range(10)) - kb = add_KB(pseudo_label_list, max_len = 5) + kb = add_KB(pseudo_label_list, kb_max_len = 5) print('len(kb):', len(kb)) - print() res = kb.get_candidates(0) print(res) - print() res = kb.get_candidates(18, length = 2) print(res) - print() + res = kb.get_candidates(18, length = 8) + print(res) res = kb.get_candidates(7, length = 3) print(res) print() - pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] - kb = hwf_KB(pseudo_label_list, max_len = 5) + # Without ground KB + pseudo_label_list = list(range(10)) + kb = add_KB(pseudo_label_list) print('len(kb):', len(kb)) + res = kb.get_candidates(0) + print(res) + res = kb.get_candidates(18, length = 2) + print(res) + res = kb.get_candidates(18, length = 8) + print(res) + res = kb.get_candidates(7, length = 3) + print(res) print() + # pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] + # kb = hwf_KB(pseudo_label_list, max_len = 5) + # print('len(kb):', len(kb)) + # print() + X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] Y = [2, 1, 1, 2, 2]