From 15e4427702179edb463e65fcedabd00e484986cb Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Thu, 24 Nov 2022 17:42:43 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 102 ++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 68 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 63a3c73..f26b363 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -14,7 +14,7 @@ import sys sys.path.append("..") import abc -from abducer.kb import add_KB, hwf_KB +from abducer.kb import add_KB, hwf_KB, add_prolog_KB import numpy as np from itertools import product, combinations @@ -67,35 +67,31 @@ class AbducerBase(abc.ABC): min_address_num = np.min(cost_list) idxs = np.where(cost_list == min_address_num)[0] return [candidates[idx] for idx in idxs][0] + + def filter_all_candidates(self, pred_res, all_candidates, max_address_num, require_more_address): + if len(all_candidates) == 0: + candidates = [] + min_address_num = 0 + address_num = 0 + else: + cost_list = self.hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates, min_address_num, address_num def abduce(self, data, max_address_num = -1, require_more_address = 0): pred_res, pred_res_prob, ans = data - if max_address_num == -1: - max_address_num = len(pred_res) if self.cache and (tuple(pred_res), ans) in self.cache_min_address_num: address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address) if (tuple(pred_res), ans, address_num) in self.cache_candidates: - # print('cached') candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)] candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidate - - if self.kb.GKB_flag: - all_candidates = self.kb.get_candidates(ans, len(pred_res)) - if len(all_candidates) == 0: - candidates = [] - min_address_num = 0 - address_num = 0 - else: - cost_list = self.hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - - else: - candidates, min_address_num, address_num = self.get_abduce_candidates(pred_res, ans, max_address_num, require_more_address) + + candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, ans, max_address_num, require_more_address) if self.cache: self.cache_min_address_num[(tuple(pred_res), ans)] = min_address_num @@ -104,47 +100,6 @@ class AbducerBase(abc.ABC): candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) return candidate - def address(self, address_num, pred_res, key): - new_candidates = [] - all_address_candidate = list(product(self.kb.pseudo_label_list, repeat = address_num)) - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) - for address_idx in address_idx_list: - for c in all_address_candidate: - address_list = [pred_res[i] for i in address_idx] - if(sum([address_list[i] == c[i] for i in range(address_num)]) == 0): - candidate = pred_res.copy() - for i, idx in enumerate(address_idx): - candidate[idx] = c[i] - if self.kb.logic_forward(candidate) == key: - new_candidates.append(candidate) - return new_candidates - - def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): - candidates = [] - print(pred_res) - for address_num in range(len(pred_res) + 1): - if address_num == 0: - if abs(self.kb.logic_forward(pred_res) - key) <= 1e-3: - candidates.append(pred_res) - else: - new_candidates = self.address(address_num, pred_res, key) - candidates += new_candidates - - if len(candidates) > 0: - min_address_num = address_num - break - - if address_num >= max_address_num: - return [], 0, 0 - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if address_num > max_address_num: - return candidates, min_address_num, address_num - 1 - new_candidates = self.address(address_num, pred_res, key) - candidates += new_candidates - - return candidates, min_address_num, address_num - def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): return [ @@ -154,21 +109,30 @@ class AbducerBase(abc.ABC): def __call__(self, Z, Y, max_address_num = -1, require_more_address = 0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) - - + if __name__ == '__main__': kb = add_KB(GKB_flag = True) abd = AbducerBase(kb, 'hamming') - res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 0) + res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 4), max_address_num = 2, require_more_address = 1) + res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) print(res) - res = abd.abduce(([1, 1], None, 5), max_address_num = 2, require_more_address = 1) + res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) print(res) print() - kb = hwf_KB() + kb = add_prolog_KB() + abd = AbducerBase(kb, 'hamming') + res = abd.abduce(([1, 1], None, 17), max_address_num = 2, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], None, 17), max_address_num = 1, require_more_address = 0) + print(res) + res = abd.abduce(([1, 1], None, 20), max_address_num = 2, require_more_address = 0) + print(res) + print() + + kb = hwf_KB(len_list = [1, 3, 5]) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) print(res) @@ -176,6 +140,8 @@ if __name__ == '__main__': print(res) res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num = 3, require_more_address = 0) print(res) - res = abd.abduce((['5', '+', '3'], None, 0.33), max_address_num = 3, require_more_address = 3) + res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num = 5, require_more_address = 3) print(res) print() + +