| @@ -32,23 +32,32 @@ class KBBase(ABC): | |||
| def abduce_candidates(self): | |||
| pass | |||
| def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address): | |||
| if len(all_candidates) == 0: | |||
| candidates = [] | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| else: | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address): | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def hamming_dist(self, A, B): | |||
| B = np.array(B) | |||
| A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) | |||
| return np.sum(A != B, axis = 1) | |||
| def __len__(self): | |||
| pass | |||
| @@ -90,66 +99,38 @@ class ClsKB(KBBase): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): | |||
| if max_address_num == -1: | |||
| max_address_num = len(pred_res) | |||
| if self.GKB_flag: | |||
| all_candidates = self.get_candidates_GKB(key, len(pred_res)) | |||
| return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) | |||
| return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | |||
| else: | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address) | |||
| def get_candidates_GKB(self, key, length = None): | |||
| if self.base == {}: | |||
| def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| return [] | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| all_candidates = self.base[len(pred_res)][key] | |||
| if length is None: | |||
| length = list(self.base.keys()) | |||
| elif type(length) is int and length not in self.len_list: | |||
| return [] | |||
| if len(all_candidates) == 0: | |||
| candidates = [] | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| else: | |||
| length = [length] | |||
| return sum([self.base[l][key] for l in length], []) | |||
| cost_list = self.hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| def get_all_candidates(self): | |||
| if self.base == {}: | |||
| return [] | |||
| else: | |||
| return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
| def hamming_dist(self, A, B): | |||
| B = np.array(B) | |||
| A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) | |||
| return np.sum(A != B, axis = 1) | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address): | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| if address_num == 0: | |||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self.address(address_num, pred_res, key) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def address(self, address_num, pred_res, key): | |||
| new_candidates = [] | |||
| @@ -214,9 +195,6 @@ class hwf_KB(ClsKB): | |||
| return round(eval(''.join(formula)), 2) | |||
| class prolog_KB(KBBase): | |||
| def __init__(self, pseudo_label_list): | |||
| super().__init__() | |||
| @@ -224,16 +202,29 @@ class prolog_KB(KBBase): | |||
| self.prolog = pyswip.Prolog() | |||
| for i in self.pseudo_label_list: | |||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||
| self.prolog_flag = True | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num = -1, require_more_address = 0): | |||
| if max_address_num == -1: | |||
| max_address_num = len(pred_res) | |||
| all_candidates = self.get_candidates_prolog(key) | |||
| return self.filter_all_candidates(pred_res, all_candidates, max_address_num, require_more_address) | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address): | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address) | |||
| def address(self, address_num, pred_res, key): | |||
| new_candidates = [] | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| query_string = "addition(" | |||
| for idx, i in enumerate(pred_res): | |||
| tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' | |||
| query_string += tmp | |||
| query_string += "%s)." | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string % key))] | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| new_candidates.append(candidate) | |||
| return new_candidates | |||
| class add_prolog_KB(prolog_KB): | |||
| @@ -301,42 +292,7 @@ class RegKB(KBBase): | |||
| import time | |||
| if __name__ == "__main__": | |||
| # With ground KB | |||
| kb = add_KB(GKB_flag = True) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates_GKB(0) | |||
| print(res) | |||
| res = kb.get_candidates_GKB(18) | |||
| print(res) | |||
| res = kb.get_candidates_GKB(18) | |||
| print(res) | |||
| res = kb.get_candidates_GKB(16) | |||
| print(res) | |||
| print() | |||
| # Without ground KB | |||
| kb = add_KB() | |||
| print('len(kb):', len(kb)) | |||
| print() | |||
| # Prolog | |||
| kb = add_prolog_KB() | |||
| print(kb.logic_forward([3, 4])) | |||
| res = kb.get_candidates_prolog(16) | |||
| print(res) | |||
| start = time.time() | |||
| kb = hwf_KB(GKB_flag = True, len_list = [1, 3, 5]) | |||
| print(time.time() - start) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates_GKB(2, length = 1) | |||
| print(res) | |||
| res = kb.get_candidates_GKB(1, length = 3) | |||
| print(res) | |||
| res = kb.get_candidates_GKB(3.67, length = 5) | |||
| print(res) | |||
| print() | |||
| pass | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||