Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
d544ccd3c1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 30 additions and 25 deletions
  1. +30
    -25
      abducer/abducer_base.py

+ 30
- 25
abducer/abducer_base.py View File

@@ -11,16 +11,12 @@
#================================================================#

import abc
# from kb import add_KB, hwf_KB
from abducer.kb import add_KB, hwf_KB
from kb import add_KB, hwf_KB
# from abducer.kb import add_KB, hwf_KB
import numpy as np

from itertools import product, combinations





class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = 'confidence', cache = True):
self.kb = kb
@@ -57,10 +53,15 @@ class AbducerBase(abc.ABC):
return self.confidence_dist(pred_res_prob, candidates)

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]
if(len(candidates) == 0):
return []
elif(len(candidates) == 1):
return candidates[0]
else:
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]

def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, pred_res_prob, ans = data
@@ -75,13 +76,16 @@ class AbducerBase(abc.ABC):
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
if(self.kb.base != {}):
if self.kb.GKB_flag:
all_candidates = self.kb.get_candidates(ans, len(pred_res))
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
if(len(all_candidates) == 0):
return []
else:
cost_list = self.hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
else:
candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address)
@@ -111,8 +115,7 @@ class AbducerBase(abc.ABC):
candidates = []
for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
print('No candidates found')
return None, None, None
return [], None, None
if(address_num == 0):
if(abs(self.kb.logic_forward(pred_res) - key) <= 1e-3):
@@ -146,24 +149,26 @@ class AbducerBase(abc.ABC):


if __name__ == '__main__':
kb = add_KB()
kb = add_KB(GKB_flag = True)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 1)
res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1)
print(res)
res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 1, require_more_address = 1)
res = abd.abduce(([1, 1], None, 4), max_address_num = 1, require_more_address = 1)
print(res)
res = abd.abduce(([1, 1, 1], None, 4), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1, 1], None, 5), max_address_num = 2, require_more_address = 1)
res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1)
print(res)
print()
kb = hwf_KB()
abd = AbducerBase(kb)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '2'], None, 3.09), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0)
print(res)
res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3)


Loading…
Cancel
Save