diff --git a/abducer/kb.py b/abducer/kb.py index c3ddff0..9c09765 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -99,10 +99,11 @@ class add_KB(KBBase): address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) for address_idx in address_idx_list: for c in all_address_candidate: - pred_res_array = np.array(pred_res) - pred_res_array[np.array(address_idx)] = c - if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): - candidates.append(pred_res_array) + if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num): + pred_res_array = np.array(pred_res) + pred_res_array[np.array(address_idx)] = c + if(self.logic_forward(pred_res_array) == key): + candidates.append(pred_res_array) if(len(candidates) > 0): min_address_num = address_num @@ -115,10 +116,11 @@ class add_KB(KBBase): address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) for address_idx in address_idx_list: for c in all_candidate: - pred_res_array = np.array(pred_res) - pred_res_array[np.array(address_idx)] = c - if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): - candidates.append(pred_res_array) + if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num): + pred_res_array = np.array(pred_res) + pred_res_array[np.array(address_idx)] = c + if(self.logic_forward(pred_res_array) == key): + candidates.append(pred_res_array) return candidates, min_address_num, address_num @@ -310,10 +312,16 @@ if __name__ == "__main__": print(res) print() + pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] + kb = hwf_KB(pseudo_label_list, max_len = 5) + print('len(kb):', len(kb)) + print() + + X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] Y = [2, 1, 1, 2, 2] kb = cls_KB(X, Y) - print(len(kb)) + print('len(kb):', len(kb)) res = kb.get_candidates(2, 5) print(res) res = kb.get_candidates(2, 3) @@ -325,7 +333,7 @@ if __name__ == "__main__": X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] Y = [2, 1, 1, 2, 1.5, 1.5] kb = reg_KB(X, Y) - print(len(kb)) + print('len(kb):', len(kb)) res = kb.get_candidates(1.6) print(res) res = kb.get_candidates(1.6, length = 9)