Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
7edd652eea
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 62 additions and 15 deletions
  1. +62
    -15
      abducer/abducer_base.py

+ 62
- 15
abducer/abducer_base.py View File

@@ -14,15 +14,14 @@ import abc
from kb import add_KB
import numpy as np

def hamming_dist(A, B):
return np.sum(np.array(A) != np.array(B))
from itertools import product, combinations

def hamming_dist_kb(A, B):
def hamming_dist(A, B):
B = np.array(B)
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)

def confidence_dist_kb(A, B):
def confidence_dist(A, B):
B = np.array(B)

#print(A)
@@ -41,7 +40,10 @@ class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True):
self.kb = kb
if dist_func == "hamming":
self.dist_func = hamming_dist
dist_func = hamming_dist
elif dist_func == "confidence":
dist_func = confidence_dist
self.dist_func = dist_func
if pred_res_parse is None:
pred_res_parse = lambda x : x["cls"]
self.pred_res_parse = pred_res_parse
@@ -50,6 +52,7 @@ class AbducerBase(abc.ABC):
self.cache_min_address_num = {}
self.cache_candidates = {}


def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1):
pred_res, ans = data

@@ -62,8 +65,16 @@ class AbducerBase(abc.ABC):
print('cached')
return self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidates, min_address_num, address_num = self.kb.get_abduce_candidates(pred_res, ans, length, self.dist_func, max_address_num, require_more_address)
if(self.kb.base != {}):
all_candidates = self.kb.get_candidates(ans, length)
cost_list = self.dist_func(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
if(self.cache):
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
@@ -71,21 +82,57 @@ class AbducerBase(abc.ABC):

return candidates
# candidates = self.kb.get_candidates(ans, length)
# cost_list = self.dist_func(pred_res, candidates)
# address_num = np.min(cost_list)
# # threshold = min(address_num + require_more_address, max_address_num)
# idxs = np.where(cost_list <= address_num + require_more_address)[0]

# return [candidates[idx] for idx in idxs], address_num
# if len(idxs) > 1:
# return None
# return [candidates[idx] for idx in idxs]
def address(self, address_num, pred_res, key):
new_candidates = []
all_address_candidate = list(product(self.kb.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_address_candidate:
pred_res_array = np.array(pred_res)
if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num):
pred_res_array[np.array(address_idx)] = c
if(self.kb.logic_forward(pred_res_array) == key):
new_candidates.append(pred_res_array)
return new_candidates, address_num
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
candidates = []

for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
print('No candidates found')
return None, None, None
if(address_num == 0):
if(self.kb.logic_forward(pred_res) == key):
candidates.append(pred_res)
else:
new_candidates, address_num = self.address(address_num, pred_res, key)
candidates += new_candidates
if(len(candidates) > 0):
min_address_num = address_num
break
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if(address_num > max_address_num):
return candidates, min_address_num, address_num - 1
new_candidates, address_num = self.address(address_num, pred_res, key)
candidates += new_candidates

return candidates, min_address_num, address_num
def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0):
return [
self.abduce((y, c), max_address_num, require_more_address)\
@@ -101,7 +148,7 @@ if __name__ == "__main__":
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
abd = AbducerBase(kb)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)
print(res)
print()
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 1)


Loading…
Cancel
Save