| @@ -99,10 +99,11 @@ class add_KB(KBBase): | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_address_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num): | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| if(len(candidates) > 0): | |||
| min_address_num = address_num | |||
| @@ -115,10 +116,11 @@ class add_KB(KBBase): | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| for c in all_candidate: | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num): | |||
| pred_res_array = np.array(pred_res) | |||
| pred_res_array[np.array(address_idx)] = c | |||
| if(self.logic_forward(pred_res_array) == key): | |||
| candidates.append(pred_res_array) | |||
| return candidates, min_address_num, address_num | |||
| @@ -310,10 +312,16 @@ if __name__ == "__main__": | |||
| print(res) | |||
| print() | |||
| pseudo_label_list = list(range(10)) + ['+', '-', '*', '/'] | |||
| kb = hwf_KB(pseudo_label_list, max_len = 5) | |||
| print('len(kb):', len(kb)) | |||
| print() | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| Y = [2, 1, 1, 2, 2] | |||
| kb = cls_KB(X, Y) | |||
| print(len(kb)) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(2, 5) | |||
| print(res) | |||
| res = kb.get_candidates(2, 3) | |||
| @@ -325,7 +333,7 @@ if __name__ == "__main__": | |||
| X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
| Y = [2, 1, 1, 2, 1.5, 1.5] | |||
| kb = reg_KB(X, Y) | |||
| print(len(kb)) | |||
| print('len(kb):', len(kb)) | |||
| res = kb.get_candidates(1.6) | |||
| print(res) | |||
| res = kb.get_candidates(1.6, length = 9) | |||