| @@ -16,8 +16,7 @@ import copy | |||
| import numpy as np | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from itertools import product | |||
| class KBBase(ABC): | |||
| def __init__(self): | |||
| @@ -46,16 +45,17 @@ class KBBase(ABC): | |||
| pass | |||
| class add_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, max_len = 5): | |||
| def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.base = {} | |||
| X = self.get_X(self.pseudo_label_list, max_len) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| self.kb_max_len = kb_max_len | |||
| if(self.kb_max_len > 0): | |||
| X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| Y = self.get_Y(X, self.logic_forward) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| @@ -71,60 +71,20 @@ class add_KB(KBBase): | |||
| return [logic_forward(nums) for nums in X] | |||
| def get_candidates(self, key, length = None): | |||
| if(self.base == {}): | |||
| return [] | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| if(self.kb_max_len < min(length)): | |||
| return [] | |||
| return sum([self.base[l][key] for l in length], []) | |||
| def get_all_candidates(self): | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if(address_num > max_address_num): | |||
| print('No candidates found') | |||
| return None, None, None | |||
| if(address_num == 0): | |||
| if(self.logic_forward(pred_res) == key): | |||
| candidates.append(pred_res) | |||
| else: | |||
| all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_address_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| if(len(candidates) > 0): | |||
| min_address_num = address_num | |||
| break | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if(address_num > max_address_num): | |||
| return candidates, min_address_num, address_num - 1 | |||
| all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num): | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| return candidates, min_address_num, address_num | |||
| def _dict_len(self, dic): | |||
| return sum(len(c) for c in dic.values()) | |||
| @@ -132,16 +92,17 @@ class add_KB(KBBase): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| # class hwf_KB(KBBase): | |||
| # def __init__(self, pseudo_label_list, max_len = 5): | |||
| # def __init__(self, pseudo_label_list, kb_max_len = -1): | |||
| # super().__init__() | |||
| # self.pseudo_label_list = pseudo_label_list | |||
| # self.base = {} | |||
| # X = self.get_X(self.pseudo_label_list, max_len) | |||
| # Y = self.get_Y(X, self.logic_forward) | |||
| # self.kb_max_len = kb_max_len | |||
| # if(self.kb_max_len > 0): | |||
| # X = self.get_X(self.pseudo_label_list, self.kb_max_len) | |||
| # Y = self.get_Y(X, self.logic_forward) | |||
| # for x, y in zip(X, Y): | |||
| # self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| # for x, y in zip(X, Y): | |||
| # self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| # def logic_forward(self, nums): | |||
| # return sum(nums) | |||
| @@ -157,65 +118,26 @@ class add_KB(KBBase): | |||
| # return [logic_forward(nums) for nums in X] | |||
| # def get_candidates(self, key, length = None): | |||
| # if(self.base == {}): | |||
| # return [] | |||
| # if key is None: | |||
| # return self.get_all_candidates() | |||
| # length = self._length(length) | |||
| # if(self.kb_max_len < min(length)): | |||
| # return [] | |||
| # return sum([self.base[l][key] for l in length], []) | |||
| # def get_all_candidates(self): | |||
| # return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| # def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): | |||
| # if key is None: | |||
| # return self.get_all_candidates() | |||
| # candidates = [] | |||
| # # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) | |||
| # for address_num in range(length + 1): | |||
| # if(address_num > max_address_num): | |||
| # print('No candidates found') | |||
| # return None, None, None | |||
| # if(address_num == 0): | |||
| # if(self.logic_forward(pred_res) == key): | |||
| # candidates.append(pred_res) | |||
| # else: | |||
| # all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| # address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| # for address_idx in address_idx_list: | |||
| # for c in all_address_candidate: | |||
| # pred_res_array = np.array(pred_res) | |||
| # pred_res_array[np.array(address_idx)] = c | |||
| # if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| # candidates.append(pred_res_array) | |||
| # if(len(candidates) > 0): | |||
| # min_address_num = address_num | |||
| # break | |||
| # for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| # if(address_num > max_address_num): | |||
| # return candidates, min_address_num, address_num - 1 | |||
| # all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) | |||
| # address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| # for address_idx in address_idx_list: | |||
| # for c in all_candidate: | |||
| # pred_res_array = np.array(pred_res) | |||
| # pred_res_array[np.array(address_idx)] = c | |||
| # if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| # candidates.append(pred_res_array) | |||
| # return candidates, min_address_num, address_num | |||
| # def _dict_len(self, dic): | |||
| # return sum(len(c) for c in dic.values()) | |||
| # def __len__(self): | |||
| # return sum(self._dict_len(v) for v in self.base.values()) | |||
| class cls_KB(KBBase): | |||
| def __init__(self, X, Y = None): | |||
| super().__init__() | |||
| @@ -298,25 +220,39 @@ class reg_KB(KBBase): | |||
| return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) | |||
| if __name__ == "__main__": | |||
| # With ground KB | |||
| pseudo_label_list = list(range(10)) | |||
| kb = add_KB(pseudo_label_list, max_len = 5) | |||
| kb = add_KB(pseudo_label_list, kb_max_len = 5) | |||
| print('len(kb):', len(kb)) | |||
| print() | |||
| res = kb.get_candidates(0) | |||
| print(res) | |||
| print() | |||
| res = kb.get_candidates(18, length = 2) | |||
| print(res) | |||
| print() | |||
| res = kb.get_candidates(18, length = 8) | |||
| print(res) | |||
| res = kb.get_candidates(7, length = 3) | |||
| print(res) | |||
| print() | |||
| pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
| kb = hwf_KB(pseudo_label_list, max_len = 5) | |||
| # Without ground KB | |||
| pseudo_label_list = list(range(10)) | |||
| kb = add_KB(pseudo_label_list) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(0) | |||
| print(res) | |||
| res = kb.get_candidates(18, length = 2) | |||
| print(res) | |||
| res = kb.get_candidates(18, length = 8) | |||
| print(res) | |||
| res = kb.get_candidates(7, length = 3) | |||
| print(res) | |||
| print() | |||
| # pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
| # kb = hwf_KB(pseudo_label_list, max_len = 5) | |||
| # print('len(kb):', len(kb)) | |||
| # print() | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| Y = [2, 1, 1, 2, 2] | |||