Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
83dff66d9f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 111 additions and 12 deletions
  1. +111
    -12
      abducer/kb.py

+ 111
- 12
abducer/kb.py View File

@@ -17,7 +17,7 @@ import numpy as np

from collections import defaultdict

from itertools import product
from itertools import product, combinations

class KBBase(ABC):
def __init__(self):
@@ -80,20 +80,30 @@ class add_KB(KBBase):
def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], [])
def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
if key is None:
return self.get_all_candidates()
candidates = []
all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res)))
for address_num in range(length + 1):
for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
print('No candidates found')
return None, None, None
for c in all_candidates:
if(dist_func(c, pred_res) == address_num):
if(self.logic_forward(c) == key):
candidates.append(c)
if(address_num == 0):
if(self.logic_forward(pred_res) == key):
candidates.append(pred_res)
else:
all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_address_candidate:
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)
if(len(candidates) > 0):
min_address_num = address_num
break
@@ -101,10 +111,14 @@ class add_KB(KBBase):
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if(address_num > max_address_num):
return candidates, min_address_num, address_num - 1
for c in all_candidates:
if(dist_func(c, pred_res) == address_num):
if(self.logic_forward(c) == key):
candidates.append(c)
all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_candidate:
pred_res_array = np.array(pred_res)
pred_res_array[np.array(address_idx)] = c
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)

return candidates, min_address_num, address_num
@@ -114,6 +128,91 @@ class add_KB(KBBase):

def __len__(self):
return sum(self._dict_len(v) for v in self.base.values())
# class hwf_KB(KBBase):
# def __init__(self, pseudo_label_list, max_len = 5):
# super().__init__()
# self.pseudo_label_list = pseudo_label_list
# self.base = {}
# X = self.get_X(self.pseudo_label_list, max_len)
# Y = self.get_Y(X, self.logic_forward)

# for x, y in zip(X, Y):
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
# def logic_forward(self, nums):
# return sum(nums)
# def get_X(self, pseudo_label_list, max_len):
# res = []
# assert(max_len >= 2)
# for len in range(2, max_len + 1):
# res += list(product(pseudo_label_list, repeat = len))
# return res

# def get_Y(self, X, logic_forward):
# return [logic_forward(nums) for nums in X]

# def get_candidates(self, key, length = None):
# if key is None:
# return self.get_all_candidates()

# length = self._length(length)
# return sum([self.base[l][key] for l in length], [])
# def get_all_candidates(self):
# return sum([sum(v.values(), []) for v in self.base.values()], [])
# def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
# if key is None:
# return self.get_all_candidates()
# candidates = []
# # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res)))

# for address_num in range(length + 1):
# if(address_num > max_address_num):
# print('No candidates found')
# return None, None, None
# if(address_num == 0):
# if(self.logic_forward(pred_res) == key):
# candidates.append(pred_res)
# else:
# all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
# address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
# for address_idx in address_idx_list:
# for c in all_address_candidate:
# pred_res_array = np.array(pred_res)
# pred_res_array[np.array(address_idx)] = c
# if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
# candidates.append(pred_res_array)
# if(len(candidates) > 0):
# min_address_num = address_num
# break
# for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
# if(address_num > max_address_num):
# return candidates, min_address_num, address_num - 1
# all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
# address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
# for address_idx in address_idx_list:
# for c in all_candidate:
# pred_res_array = np.array(pred_res)
# pred_res_array[np.array(address_idx)] = c
# if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
# candidates.append(pred_res_array)

# return candidates, min_address_num, address_num

# def _dict_len(self, dic):
# return sum(len(c) for c in dic.values())

# def __len__(self):
# return sum(self._dict_len(v) for v in self.base.values())

class cls_KB(KBBase):
def __init__(self, X, Y = None):


Loading…
Cancel
Save