Browse Source

Create kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
7f07c781d0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 10 deletions
  1. +18
    -10
      abducer/kb.py

+ 18
- 10
abducer/kb.py View File

@@ -99,10 +99,11 @@ class add_KB(KBBase):
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_address_candidate:
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)
if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num):
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)
if(len(candidates) > 0):
min_address_num = address_num
@@ -115,10 +116,11 @@ class add_KB(KBBase):
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_candidate:
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)
if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num):
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)

return candidates, min_address_num, address_num
@@ -310,10 +312,16 @@ if __name__ == "__main__":
print(res)
print()
pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
kb = hwf_KB(pseudo_label_list, max_len = 5)
print('len(kb):', len(kb))
print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
Y = [2, 1, 1, 2, 2]
kb = cls_KB(X, Y)
print(len(kb))
print('len(kb):', len(kb))
res = kb.get_candidates(2, 5)
print(res)
res = kb.get_candidates(2, 3)
@@ -325,7 +333,7 @@ if __name__ == "__main__":
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
Y = [2, 1, 1, 2, 1.5, 1.5]
kb = reg_KB(X, Y)
print(len(kb))
print('len(kb):', len(kb))
res = kb.get_candidates(1.6)
print(res)
res = kb.get_candidates(1.6, length = 9)


Loading…
Cancel
Save