Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
20146ee8dd
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 14 deletions
  1. +26
    -14
      abducer/abducer_base.py

+ 26
- 14
abducer/abducer_base.py View File

@@ -17,7 +17,6 @@ import abc
from abducer.kb import add_KB, hwf_KB, add_prolog_KB
import numpy as np

from itertools import product, combinations
import time

class AbducerBase(abc.ABC):
@@ -26,7 +25,7 @@ class AbducerBase(abc.ABC):
assert(dist_func == 'hamming' or dist_func == 'confidence')
self.dist_func = dist_func
self.cache = cache
if self.cache:
self.cache_min_address_num = {}
self.cache_candidates = {}
@@ -83,21 +82,23 @@ class AbducerBase(abc.ABC):

def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, pred_res_prob, ans = data
if max_address_num == -1:
max_address_num = len(pred_res)

if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address)
if (tuple(pred_res), ans, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
return self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address)
if self.cache:
self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num
self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return candidate
@@ -112,23 +113,34 @@ class AbducerBase(abc.ABC):

if __name__ == '__main__':
prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
kb = add_KB(GKB_flag = True)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
res = abd.abduce(([1, 1], prob1, 17), max_address_num = 1, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1], prob1, 20), max_address_num = 2, require_more_address = 0)
print(res)
print()
kb = add_prolog_KB()
abd = AbducerBase(kb, 'hamming')
res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0)
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], prob1, 17), max_address_num = 2, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0)
res = abd.abduce(([1, 1], prob1, 17), max_address_num = 1, require_more_address = 0)
print(res)
res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0)
res = abd.abduce(([1, 1], prob1, 20), max_address_num = 2, require_more_address = 0)
print(res)
print()


Loading…
Cancel
Save