diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index f26b363..93810c1 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -17,7 +17,6 @@ import abc from abducer.kb import add_KB, hwf_KB, add_prolog_KB import numpy as np -from itertools import product, combinations import time class AbducerBase(abc.ABC): @@ -26,7 +25,7 @@ class AbducerBase(abc.ABC): assert(dist_func == 'hamming' or dist_func == 'confidence') self.dist_func = dist_func self.cache = cache - + if self.cache: self.cache_min_address_num = {} self.cache_candidates = {} @@ -83,21 +82,23 @@ class AbducerBase(abc.ABC): def abduce(self, data, max_address_num = -1, require_more_address = 0): pred_res, pred_res_prob, ans = data + if max_address_num == -1: + max_address_num = len(pred_res) if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num: address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) if (tuple(pred_res), ans, address_num) in self.cache_candidates: candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] - candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) - return candidate + return self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address) + if self.cache: self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num self.cache_candidates[(tuple(pred_res), ans, address_num)] = candidates - - candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) + + candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidate @@ -112,23 +113,34 @@ class AbducerBase(abc.ABC): if __name__ == '__main__': + prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + kb = add_KB(GKB_flag = True) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) + abd = AbducerBase(kb, 'confidence') + res = abd.abduce(([1, 1], prob1, 8), max_address_num = 2, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], prob2, 8), max_address_num = 2, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], prob1, 17), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) + res = abd.abduce(([1, 1], prob1, 17), max_address_num = 1, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1], prob1, 20), max_address_num = 2, require_more_address = 0) print(res) print() kb = add_prolog_KB() - abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) + abd = AbducerBase(kb, 'confidence') + res = abd.abduce(([1, 1], prob1, 8), max_address_num = 2, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], prob2, 8), max_address_num = 2, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], prob1, 17), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) + res = abd.abduce(([1, 1], prob1, 17), max_address_num = 1, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1], prob1, 20), max_address_num = 2, require_more_address = 0) print(res) print()