Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
5beb068ee1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 46 additions and 110 deletions
  1. +46
    -110
      abducer/kb.py

+ 46
- 110
abducer/kb.py View File

@@ -16,8 +16,7 @@ import copy
import numpy as np

from collections import defaultdict

from itertools import product, combinations
from itertools import product

class KBBase(ABC):
def __init__(self):
@@ -46,16 +45,17 @@ class KBBase(ABC):
pass

class add_KB(KBBase):
def __init__(self, pseudo_label_list, max_len = 5):
def __init__(self, pseudo_label_list, kb_max_len = -1):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.base = {}
X = self.get_X(self.pseudo_label_list, max_len)
Y = self.get_Y(X, self.logic_forward)
self.kb_max_len = kb_max_len
if(self.kb_max_len > 0):
X = self.get_X(self.pseudo_label_list, self.kb_max_len)
Y = self.get_Y(X, self.logic_forward)

for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
def logic_forward(self, nums):
return sum(nums)
@@ -71,60 +71,20 @@ class add_KB(KBBase):
return [logic_forward(nums) for nums in X]

def get_candidates(self, key, length = None):
if(self.base == {}):
return []
if key is None:
return self.get_all_candidates()

length = self._length(length)
if(self.kb_max_len < min(length)):
return []
return sum([self.base[l][key] for l in length], [])
def get_all_candidates(self):
return sum([sum(v.values(), []) for v in self.base.values()], [])
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
if key is None:
return self.get_all_candidates()
candidates = []

for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
print('No candidates found')
return None, None, None
if(address_num == 0):
if(self.logic_forward(pred_res) == key):
candidates.append(pred_res)
else:
all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_address_candidate:
pred_res_array = np.array(pred_res)
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
pred_res_array[np.array(address_idx)] = c
if(self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)
if(len(candidates) > 0):
min_address_num = address_num
break
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if(address_num > max_address_num):
return candidates, min_address_num, address_num - 1
all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_candidate:
pred_res_array = np.array(pred_res)
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
pred_res_array[np.array(address_idx)] = c
if(self.logic_forward(pred_res_array) == key):
candidates.append(pred_res_array)

return candidates, min_address_num, address_num

def _dict_len(self, dic):
return sum(len(c) for c in dic.values())

@@ -132,16 +92,17 @@ class add_KB(KBBase):
return sum(self._dict_len(v) for v in self.base.values())
# class hwf_KB(KBBase):
# def __init__(self, pseudo_label_list, max_len = 5):
# def __init__(self, pseudo_label_list, kb_max_len = -1):
# super().__init__()
# self.pseudo_label_list = pseudo_label_list
# self.base = {}
# X = self.get_X(self.pseudo_label_list, max_len)
# Y = self.get_Y(X, self.logic_forward)
# self.kb_max_len = kb_max_len
# if(self.kb_max_len > 0):
# X = self.get_X(self.pseudo_label_list, self.kb_max_len)
# Y = self.get_Y(X, self.logic_forward)

# for x, y in zip(X, Y):
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
# for x, y in zip(X, Y):
# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
# def logic_forward(self, nums):
# return sum(nums)
@@ -157,65 +118,26 @@ class add_KB(KBBase):
# return [logic_forward(nums) for nums in X]

# def get_candidates(self, key, length = None):
# if(self.base == {}):
# return []
# if key is None:
# return self.get_all_candidates()

# length = self._length(length)
# if(self.kb_max_len < min(length)):
# return []
# return sum([self.base[l][key] for l in length], [])
# def get_all_candidates(self):
# return sum([sum(v.values(), []) for v in self.base.values()], [])
# def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address):
# if key is None:
# return self.get_all_candidates()
# candidates = []
# # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res)))

# for address_num in range(length + 1):
# if(address_num > max_address_num):
# print('No candidates found')
# return None, None, None
# if(address_num == 0):
# if(self.logic_forward(pred_res) == key):
# candidates.append(pred_res)
# else:
# all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num))
# address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
# for address_idx in address_idx_list:
# for c in all_address_candidate:
# pred_res_array = np.array(pred_res)
# pred_res_array[np.array(address_idx)] = c
# if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
# candidates.append(pred_res_array)
# if(len(candidates) > 0):
# min_address_num = address_num
# break
# for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
# if(address_num > max_address_num):
# return candidates, min_address_num, address_num - 1
# all_candidate = list(product(self.pseudo_label_list, repeat = address_num))
# address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
# for address_idx in address_idx_list:
# for c in all_candidate:
# pred_res_array = np.array(pred_res)
# pred_res_array[np.array(address_idx)] = c
# if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key):
# candidates.append(pred_res_array)

# return candidates, min_address_num, address_num

# def _dict_len(self, dic):
# return sum(len(c) for c in dic.values())

# def __len__(self):
# return sum(self._dict_len(v) for v in self.base.values())
class cls_KB(KBBase):
def __init__(self, X, Y = None):
super().__init__()
@@ -298,25 +220,39 @@ class reg_KB(KBBase):
return sum([sum(len(x) for x in D[0]) for D in self.base.values()])

if __name__ == "__main__":
# With ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list, max_len = 5)
kb = add_KB(pseudo_label_list, kb_max_len = 5)
print('len(kb):', len(kb))
print()
res = kb.get_candidates(0)
print(res)
print()
res = kb.get_candidates(18, length = 2)
print(res)
print()
res = kb.get_candidates(18, length = 8)
print(res)
res = kb.get_candidates(7, length = 3)
print(res)
print()
pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
kb = hwf_KB(pseudo_label_list, max_len = 5)
# Without ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
print('len(kb):', len(kb))
res = kb.get_candidates(0)
print(res)
res = kb.get_candidates(18, length = 2)
print(res)
res = kb.get_candidates(18, length = 8)
print(res)
res = kb.get_candidates(7, length = 3)
print(res)
print()
# pseudo_label_list = list(range(10)) + ['+', '-', '*', '/']
# kb = hwf_KB(pseudo_label_list, max_len = 5)
# print('len(kb):', len(kb))
# print()
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
Y = [2, 1, 1, 2, 2]


Loading…
Cancel
Save