Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
15e4427702
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 68 deletions
  1. +34
    -68
      abducer/abducer_base.py

+ 34
- 68
abducer/abducer_base.py View File

@@ -14,7 +14,7 @@ import sys
sys.path.append("..")

import abc
from abducer.kb import add_KB, hwf_KB
from abducer.kb import add_KB, hwf_KB, add_prolog_KB
import numpy as np

from itertools import product, combinations
@@ -67,35 +67,31 @@ class AbducerBase(abc.ABC):
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]
def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address):
if len(all_candidates) == 0:
candidates = []
min_address_num = 0
address_num = 0
else:
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num

def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, pred_res_prob, ans = data
if max_address_num == -1:
max_address_num = len(pred_res)

if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address)
if (tuple(pred_res), ans, address_num) in self.cache_candidates:
# print('cached')
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
if self.kb.GKB_flag:
all_candidates = self.kb.get_candidates(ans, len(pred_res))
if len(all_candidates) == 0:
candidates = []
min_address_num = 0
address_num = 0
else:
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address)
if self.cache:
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
@@ -104,47 +100,6 @@ class AbducerBase(abc.ABC):
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
def address(self, address_num, pred_res, key):
new_candidates = []
all_address_candidate = list(product(self.kb.pseudo_label_list, repeat = address_num))
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
for address_idx in address_idx_list:
for c in all_address_candidate:
address_list = [pred_res[i] for i in address_idx]
if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0):
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if self.kb.logic_forward(candidate) == key:
new_candidates.append(candidate)
return new_candidates
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
candidates = []
print(pred_res)
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3:
candidates.append(pred_res)
else:
new_candidates = self.address(address_num, pred_res, key)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
return [], 0, 0
for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self.address(address_num, pred_res, key)
candidates += new_candidates

return candidates, min_address_num, address_num
def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0):
return [
@@ -154,21 +109,30 @@ class AbducerBase(abc.ABC):

def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address)



if __name__ == '__main__':
kb = add_KB(GKB_flag = True)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
print(res)
print()
kb = hwf_KB()
kb = add_prolog_KB()
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
print(res)
print()
kb = hwf_KB(len_list = [1, 3, 5])
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0)
print(res)
@@ -176,6 +140,8 @@ if __name__ == '__main__':
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num = 5, require_more_address = 3)
print(res)
print()

Loading…
Cancel
Save