| @@ -17,7 +17,7 @@ import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product | |||
| from itertools import product, combinations | |||
| class KBBase(ABC): | |||
| def __init__(self): | |||
| @@ -80,20 +80,30 @@ class add_KB(KBBase): | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): | |||
| def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| candidates = [] | |||
| all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) | |||
| for address_num in range(length + 1): | |||
| for address_num in range(len(pred_res) + 1): | |||
| if(address_num > max_address_num): | |||
| print('No candidates found') | |||
| return None, None, None | |||
| for c in all_candidates: | |||
| if(dist_func(c, pred_res) == address_num): | |||
| if(self.logic_forward(c) == key): | |||
| candidates.append(c) | |||
| if(address_num == 0): | |||
| if(self.logic_forward(pred_res) == key): | |||
| candidates.append(pred_res) | |||
| else: | |||
| all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_address_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| if(len(candidates) > 0): | |||
| min_address_num = address_num | |||
| break | |||
| @@ -101,10 +111,14 @@ class add_KB(KBBase): | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if(address_num > max_address_num): | |||
| return candidates, min_address_num, address_num - 1 | |||
| for c in all_candidates: | |||
| if(dist_func(c, pred_res) == address_num): | |||
| if(self.logic_forward(c) == key): | |||
| candidates.append(c) | |||
| all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| return candidates, min_address_num, address_num | |||
| @@ -114,6 +128,91 @@ class add_KB(KBBase): | |||
| def __len__(self): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| # class hwf_KB(KBBase): | |||
| # def __init__(self, pseudo_label_list, max_len = 5): | |||
| # super().__init__() | |||
| # self.pseudo_label_list = pseudo_label_list | |||
| # self.base = {} | |||
| # X = self.get_X(self.pseudo_label_list, max_len) | |||
| # Y = self.get_Y(X, self.logic_forward) | |||
| # for x, y in zip(X, Y): | |||
| # self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| # def logic_forward(self, nums): | |||
| # return sum(nums) | |||
| # def get_X(self, pseudo_label_list, max_len): | |||
| # res = [] | |||
| # assert(max_len >= 2) | |||
| # for len in range(2, max_len + 1): | |||
| # res += list(product(pseudo_label_list, repeat = len)) | |||
| # return res | |||
| # def get_Y(self, X, logic_forward): | |||
| # return [logic_forward(nums) for nums in X] | |||
| # def get_candidates(self, key, length = None): | |||
| # if key is None: | |||
| # return self.get_all_candidates() | |||
| # length = self._length(length) | |||
| # return sum([self.base[l][key] for l in length], []) | |||
| # def get_all_candidates(self): | |||
| # return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| # def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): | |||
| # if key is None: | |||
| # return self.get_all_candidates() | |||
| # candidates = [] | |||
| # # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) | |||
| # for address_num in range(length + 1): | |||
| # if(address_num > max_address_num): | |||
| # print('No candidates found') | |||
| # return None, None, None | |||
| # if(address_num == 0): | |||
| # if(self.logic_forward(pred_res) == key): | |||
| # candidates.append(pred_res) | |||
| # else: | |||
| # all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| # address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| # for address_idx in address_idx_list: | |||
| # for c in all_address_candidate: | |||
| # pred_res_array = np.array(pred_res) | |||
| # pred_res_array[np.array(address_idx)] = c | |||
| # if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| # candidates.append(pred_res_array) | |||
| # if(len(candidates) > 0): | |||
| # min_address_num = address_num | |||
| # break | |||
| # for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| # if(address_num > max_address_num): | |||
| # return candidates, min_address_num, address_num - 1 | |||
| # all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| # address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| # for address_idx in address_idx_list: | |||
| # for c in all_candidate: | |||
| # pred_res_array = np.array(pred_res) | |||
| # pred_res_array[np.array(address_idx)] = c | |||
| # if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| # candidates.append(pred_res_array) | |||
| # return candidates, min_address_num, address_num | |||
| # def _dict_len(self, dic): | |||
| # return sum(len(c) for c in dic.values()) | |||
| # def __len__(self): | |||
| # return sum(self._dict_len(v) for v in self.base.values()) | |||
| class cls_KB(KBBase): | |||
| def __init__(self, X, Y = None): | |||