From 3d14196107527bb55f5b91ddff064a036cb067ce Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Sat, 10 Dec 2022 22:11:41 +0800 Subject: [PATCH] Update abducer_base.py --- abducer/abducer_base.py | 118 ++++++++++++++++++++++++++++------------ 1 file changed, 84 insertions(+), 34 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 1e8522d..7b89791 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -26,9 +26,16 @@ import time class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): + def __init__( + self, + kb, + dist_func="confidence", + zoopt=False, + multiple_predictions=False, + cache=True, + ): self.kb = kb - assert dist_func == 'hamming' or dist_func == 'confidence' + assert dist_func == "hamming" or dist_func == "confidence" self.dist_func = dist_func self.zoopt = zoopt self.multiple_predictions = multiple_predictions @@ -39,11 +46,18 @@ class AbducerBase(abc.ABC): self.cache_candidates = {} def _get_cost_list(self, pred_res, pred_res_prob, candidates): - if self.dist_func == 'hamming': + if self.dist_func == "hamming": return hamming_dist(pred_res, candidates) - elif self.dist_func == 'confidence': - mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) - return confidence_dist(pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]) + elif self.dist_func == "confidence": + mapping = dict( + zip( + self.kb.pseudo_label_list, + list(range(len(self.kb.pseudo_label_list))), + ) + ) + return confidence_dist( + pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] + ) def _get_one_candidate(self, pred_res, pred_res_prob, candidates): if len(candidates) == 0: @@ -61,8 +75,12 @@ class AbducerBase(abc.ABC): all_address_flag = reform_idx(solution, pred_res) score = 0 for idx in enumerate(len(pred_res)): - address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) + address_idx = [ + i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 + ] + candidate = self.kb.address_by_idx( + [pred_res[idx]], key[idx], address_idx, True + ) if len(candidate) > 0: score += 1 return score @@ -70,7 +88,9 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, key, sol): if not self.multiple_predictions: address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.kb.address_by_idx( + pred_res, key, address_idx, self.multiple_predictions + ) return 1 if len(candidates) > 0 else 0 else: return self._zoopt_score_multiple(pred_res, key, sol.get_x()) @@ -98,7 +118,11 @@ class AbducerBase(abc.ABC): pred_res = flatten(pred_res) key = tuple(key) if (tuple(pred_res), key) in self.cache_min_address_num: - address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) + address_num = min( + max_address_num, + self.cache_min_address_num[(tuple(pred_res), key)] + + require_more_address, + ) if (tuple(pred_res), key, address_num) in self.cache_candidates: candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] if self.zoopt: @@ -127,12 +151,18 @@ class AbducerBase(abc.ABC): if self.zoopt: solution = self.zoopt_get_solution(pred_res, key, max_address_num) address_idx = [idx for idx, i in enumerate(solution) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.kb.address_by_idx( + pred_res, key, address_idx, self.multiple_predictions + ) address_num = int(solution.sum()) min_address_num = address_num else: candidates, min_address_num, address_num = self.kb.abduce_candidates( - pred_res, key, max_address_num, require_more_address, self.multiple_predictions + pred_res, + key, + max_address_num, + require_more_address, + self.multiple_predictions, ) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) @@ -147,23 +177,31 @@ class AbducerBase(abc.ABC): def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): if self.multiple_predictions: - return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) + return self.abduce( + (Z["cls"], Z["prob"], Y), max_address_num, require_more_address + ) else: - return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] + return [ + self.abduce((z, prob, y), max_address_num, require_more_address) + for z, prob, y in zip(Z["cls"], Z["prob"], Y) + ] def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) - - - -if __name__ == '__main__': - prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] - prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] +if __name__ == "__main__": + prob1 = [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + prob2 = [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] kb = add_KB() - abd = AbducerBase(kb, 'confidence') + abd = AbducerBase(kb, "confidence") res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -177,7 +215,7 @@ if __name__ == '__main__': print() kb = add_prolog_KB() - abd = AbducerBase(kb, 'confidence') + abd = AbducerBase(kb, "confidence") res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -191,7 +229,7 @@ if __name__ == '__main__': print() kb = add_prolog_KB() - abd = AbducerBase(kb, 'confidence', zoopt=True) + abd = AbducerBase(kb, "confidence", zoopt=True) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -205,24 +243,38 @@ if __name__ == '__main__': print() kb = HWF_KB(len_list=[1, 3, 5]) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + abd = AbducerBase(kb, "hamming") + res = abd.abduce( + (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '+', '2'], None, 64), max_address_num=3, require_more_address=0) + res = abd.abduce( + (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) + res = abd.abduce( + (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) + res = abd.abduce( + (["5", "8", "8", "8", "8"], None, 3.17), + max_address_num=5, + require_more_address=3, + ) print(res) print() kb = HED_prolog_KB() abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) - consist_exs = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]] - consist_exs2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules - inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] + consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] + consist_exs2 = [ + [1, "+", 0, "=", 0], + [1, "+", 1, "=", 0], + [0, "+", 1, "=", 1, 1], + ] # not consistent with rules + inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] - rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])'] + rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) @@ -236,5 +288,3 @@ if __name__ == '__main__': abduced_rules = abd.abduce_rules(consist_exs) print(abduced_rules) - -